+ {% if error is defined %}
+
Error: {{ error }}
+ {% endif %}
Please click the button below if you agree to the
privacy policy of this homeserver.
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index d5862a4da436..b2514d9d0df4 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -36,7 +36,11 @@
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
-from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
+from synapse.rest.admin.registration_tokens import (
+ ListRegistrationTokensRestServlet,
+ NewRegistrationTokenRestServlet,
+ RegistrationTokenRestServlet,
+)
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,
@@ -47,7 +51,6 @@
RoomMembersRestServlet,
RoomRestServlet,
RoomStateRestServlet,
- ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
@@ -220,8 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
- PurgeRoomServlet(hs).register(http_server)
- SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
@@ -241,6 +242,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomEventContextServlet(hs).register(http_server)
RateLimitRestServlet(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server)
+ ListRegistrationTokensRestServlet(hs).register(http_server)
+ NewRegistrationTokenRestServlet(hs).register(http_server)
+ RegistrationTokenRestServlet(hs).register(http_server)
+
+ # Some servlets only get registered for the main process.
+ if hs.config.worker_app is None:
+ SendServerNoticeServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
@@ -253,7 +261,6 @@ def register_servlets_for_client_rest_resource(
PurgeHistoryRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
- ShutdownRoomRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
deleted file mode 100644
index 2365ff7a0f26..000000000000
--- a/synapse/rest/admin/purge_room_servlet.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING, Tuple
-
-from synapse.http.servlet import (
- RestServlet,
- assert_params_in_dict,
- parse_json_object_from_request,
-)
-from synapse.http.site import SynapseRequest
-from synapse.rest.admin import assert_requester_is_admin
-from synapse.rest.admin._base import admin_patterns
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class PurgeRoomServlet(RestServlet):
- """Servlet which will remove all trace of a room from the database
-
- POST /_synapse/admin/v1/purge_room
- {
- "room_id": "!room:id"
- }
-
- returns:
-
- {}
- """
-
- PATTERNS = admin_patterns("/purge_room$")
-
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.auth = hs.get_auth()
- self.pagination_handler = hs.get_pagination_handler()
-
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- await assert_requester_is_admin(self.auth, request)
-
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ("room_id",))
-
- await self.pagination_handler.purge_room(body["room_id"])
-
- return 200, {}
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
new file mode 100644
index 000000000000..5a1c929d85c4
--- /dev/null
+++ b/synapse/rest/admin/registration_tokens.py
@@ -0,0 +1,321 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import string
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_boolean,
+ parse_json_object_from_request,
+)
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListRegistrationTokensRestServlet(RestServlet):
+ """List registration tokens.
+
+ To list all tokens:
+
+ GET /_synapse/admin/v1/registration_tokens
+
+ 200 OK
+
+ {
+ "registration_tokens": [
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "wxyz",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 9,
+ "expiry_time": 1625394937000
+ }
+ ]
+ }
+
+ The optional query parameter `valid` can be used to filter the response.
+ If it is `true`, only valid tokens are returned. If it is `false`, only
+ tokens that have expired or have had all uses exhausted are returned.
+ If it is omitted, all tokens are returned regardless of validity.
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+ valid = parse_boolean(request, "valid")
+ token_list = await self.store.get_registration_tokens(valid)
+ return 200, {"registration_tokens": token_list}
+
+
+class NewRegistrationTokenRestServlet(RestServlet):
+ """Create a new registration token.
+
+ For example, to create a token specifying some fields:
+
+ POST /_synapse/admin/v1/registration_tokens/new
+
+ {
+ "token": "defg",
+ "uses_allowed": 1
+ }
+
+ 200 OK
+
+ {
+ "token": "defg",
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": null
+ }
+
+ Defaults are used for any fields not specified.
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens/new$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ # A string of all the characters allowed to be in a registration_token
+ self.allowed_chars = string.ascii_letters + string.digits + "-_"
+ self.allowed_chars_set = set(self.allowed_chars)
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+
+ if "token" in body:
+ token = body["token"]
+ if not isinstance(token, str):
+ raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+ if not (0 < len(token) <= 64):
+ raise SynapseError(
+ 400,
+ "token must not be empty and must not be longer than 64 characters",
+ Codes.INVALID_PARAM,
+ )
+ if not set(token).issubset(self.allowed_chars_set):
+ raise SynapseError(
+ 400,
+ "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
+ Codes.INVALID_PARAM,
+ )
+
+ else:
+ # Get length of token to generate (default is 16)
+ length = body.get("length", 16)
+ if not isinstance(length, int):
+ raise SynapseError(
+ 400, "length must be an integer", Codes.INVALID_PARAM
+ )
+ if not (0 < length <= 64):
+ raise SynapseError(
+ 400,
+ "length must be greater than zero and not greater than 64",
+ Codes.INVALID_PARAM,
+ )
+
+ # Generate token
+ token = await self.store.generate_registration_token(
+ length, self.allowed_chars
+ )
+
+ uses_allowed = body.get("uses_allowed", None)
+ if not (
+ uses_allowed is None
+ or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ ):
+ raise SynapseError(
+ 400,
+ "uses_allowed must be a non-negative integer or null",
+ Codes.INVALID_PARAM,
+ )
+
+ expiry_time = body.get("expiry_time", None)
+ if not isinstance(expiry_time, (int, type(None))):
+ raise SynapseError(
+ 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ )
+ if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ raise SynapseError(
+ 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ )
+
+ created = await self.store.create_registration_token(
+ token, uses_allowed, expiry_time
+ )
+ if not created:
+ raise SynapseError(
+ 400, f"Token already exists: {token}", Codes.INVALID_PARAM
+ )
+
+ resp = {
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ }
+ return 200, resp
+
+
+class RegistrationTokenRestServlet(RestServlet):
+ """Retrieve, update, or delete the given token.
+
+ For example,
+
+ to retrieve a token:
+
+ GET /_synapse/admin/v1/registration_tokens/abcd
+
+ 200 OK
+
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ }
+
+
+ to update a token:
+
+ PUT /_synapse/admin/v1/registration_tokens/defg
+
+ {
+ "uses_allowed": 5,
+ "expiry_time": 4781243146000
+ }
+
+ 200 OK
+
+ {
+ "token": "defg",
+ "uses_allowed": 5,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": 4781243146000
+ }
+
+
+ to delete a token:
+
+ DELETE /_synapse/admin/v1/registration_tokens/wxyz
+
+ 200 OK
+
+ {}
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+ """Retrieve a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+ token_info = await self.store.get_one_registration_token(token)
+
+ # If no result return a 404
+ if token_info is None:
+ raise NotFoundError(f"No such registration token: {token}")
+
+ return 200, token_info
+
+ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+ """Update a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+ new_attributes = {}
+
+ # Only add uses_allowed to new_attributes if it is present and valid
+ if "uses_allowed" in body:
+ uses_allowed = body["uses_allowed"]
+ if not (
+ uses_allowed is None
+ or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ ):
+ raise SynapseError(
+ 400,
+ "uses_allowed must be a non-negative integer or null",
+ Codes.INVALID_PARAM,
+ )
+ new_attributes["uses_allowed"] = uses_allowed
+
+ if "expiry_time" in body:
+ expiry_time = body["expiry_time"]
+ if not isinstance(expiry_time, (int, type(None))):
+ raise SynapseError(
+ 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ )
+ if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ raise SynapseError(
+ 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ )
+ new_attributes["expiry_time"] = expiry_time
+
+ if len(new_attributes) == 0:
+ # Nothing to update, get token info to return
+ token_info = await self.store.get_one_registration_token(token)
+ else:
+ token_info = await self.store.update_registration_token(
+ token, new_attributes
+ )
+
+ # If no result return a 404
+ if token_info is None:
+ raise NotFoundError(f"No such registration token: {token}")
+
+ return 200, token_info
+
+ async def on_DELETE(
+ self, request: SynapseRequest, token: str
+ ) -> Tuple[int, JsonDict]:
+ """Delete a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+
+ if await self.store.delete_registration_token(token):
+ return 200, {}
+
+ raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 975c28b2258e..ad83d4b54c64 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -46,41 +46,6 @@
logger = logging.getLogger(__name__)
-class ShutdownRoomRestServlet(RestServlet):
- """Shuts down a room by removing all local users from the room and blocking
- all future invites and joins to the room. Any local aliases will be repointed
- to a new room created by `new_room_user_id` and kicked users will be auto
- joined to the new room.
- """
-
- PATTERNS = admin_patterns("/shutdown_room/(?P[^/]+)")
-
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.auth = hs.get_auth()
- self.room_shutdown_handler = hs.get_room_shutdown_handler()
-
- async def on_POST(
- self, request: SynapseRequest, room_id: str
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- content = parse_json_object_from_request(request)
- assert_params_in_dict(content, ["new_room_user_id"])
-
- ret = await self.room_shutdown_handler.shutdown_room(
- room_id=room_id,
- new_room_user_id=content["new_room_user_id"],
- new_room_name=content.get("room_name"),
- message=content.get("message"),
- requester_user_id=requester.user.to_string(),
- block=True,
- )
-
- return (200, ret)
-
-
class DeleteRoomRestServlet(RestServlet):
"""Delete a room from server.
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index b5e4c474efc8..42201afc86d3 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
+ self.server_notices_manager = hs.get_server_notices_manager()
+ self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer):
@@ -79,19 +81,22 @@ async def on_POST(
# We grab the server notices manager here as its initialisation has a check for worker processes,
# but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
# admin api).
- if not self.hs.get_server_notices_manager().is_enabled():
+ if not self.server_notices_manager.is_enabled():
raise SynapseError(400, "Server notices are not enabled on this server")
- user_id = body["user_id"]
- UserID.from_string(user_id)
- if not self.hs.is_mine_id(user_id):
+ target_user = UserID.from_string(body["user_id"])
+ if not self.hs.is_mine(target_user):
raise SynapseError(400, "Server notices can only be sent to local users")
- event = await self.hs.get_server_notices_manager().send_notice(
- user_id=body["user_id"],
+ if not await self.admin_handler.get_user(target_user):
+ raise NotFoundError("User not found")
+
+ event = await self.server_notices_manager.send_notice(
+ user_id=target_user.to_string(),
type=event_type,
state_key=state_key,
event_content=body["content"],
+ txn_id=txn_id,
)
return 200, {"event_id": event.event_id}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 5429e945e302..87a59872644c 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -229,13 +229,18 @@ async def on_PUT(
if not isinstance(deactivate, bool):
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
- # convert into List[Tuple[str, str]]
+ # convert List[Dict[str, str]] into Set[Tuple[str, str]]
if external_ids is not None:
- new_external_ids = []
- for external_id in external_ids:
- new_external_ids.append(
- (external_id["auth_provider"], external_id["external_id"])
- )
+ new_external_ids = {
+ (external_id["auth_provider"], external_id["external_id"])
+ for external_id in external_ids
+ }
+
+ # convert List[Dict[str, str]] into Set[Tuple[str, str]]
+ if threepids is not None:
+ new_threepids = {
+ (threepid["medium"], threepid["address"]) for threepid in threepids
+ }
if user: # modify user
if "displayname" in body:
@@ -244,29 +249,39 @@ async def on_PUT(
)
if threepids is not None:
- # remove old threepids from user
- old_threepids = await self.store.user_get_threepids(user_id)
- for threepid in old_threepids:
+ # get changed threepids (added and removed)
+ # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
+ cur_threepids = {
+ (threepid["medium"], threepid["address"])
+ for threepid in await self.store.user_get_threepids(user_id)
+ }
+ add_threepids = new_threepids - cur_threepids
+ del_threepids = cur_threepids - new_threepids
+
+ # remove old threepids
+ for medium, address in del_threepids:
try:
await self.auth_handler.delete_threepid(
- user_id, threepid["medium"], threepid["address"], None
+ user_id, medium, address, None
)
except Exception:
logger.exception("Failed to remove threepids")
raise SynapseError(500, "Failed to remove threepids")
- # add new threepids to user
+ # add new threepids
current_time = self.hs.get_clock().time_msec()
- for threepid in threepids:
+ for medium, address in add_threepids:
await self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], current_time
+ user_id, medium, address, current_time
)
if external_ids is not None:
# get changed external_ids (added and removed)
- cur_external_ids = await self.store.get_external_ids_by_user(user_id)
- add_external_ids = set(new_external_ids) - set(cur_external_ids)
- del_external_ids = set(cur_external_ids) - set(new_external_ids)
+ cur_external_ids = set(
+ await self.store.get_external_ids_by_user(user_id)
+ )
+ add_external_ids = new_external_ids - cur_external_ids
+ del_external_ids = cur_external_ids - new_external_ids
# remove old external_ids
for auth_provider, external_id in del_external_ids:
@@ -349,9 +364,9 @@ async def on_PUT(
if threepids is not None:
current_time = self.hs.get_clock().time_msec()
- for threepid in threepids:
+ for medium, address in new_threepids:
await self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], current_time
+ user_id, medium, address, current_time
)
if (
self.hs.config.email_enable_notifs
@@ -363,8 +378,8 @@ async def on_PUT(
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
- device_display_name=threepid["address"],
- pushkey=threepid["address"],
+ device_display_name=address,
+ pushkey=address,
lang=None, # We don't know a user's language here
data={},
)
diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py
index 3ebe40186153..6c24b96c547d 100644
--- a/synapse/rest/client/account_validity.py
+++ b/synapse/rest/client/account_validity.py
@@ -13,24 +13,27 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
-from synapse.api.errors import SynapseError
-from synapse.http.server import respond_with_html
-from synapse.http.servlet import RestServlet
+from twisted.web.server import Request
+
+from synapse.http.server import HttpServer, respond_with_html
+from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -46,18 +49,14 @@ def __init__(self, hs):
hs.config.account_validity.account_validity_invalid_token_template
)
- async def on_GET(self, request):
- if b"token" not in request.args:
- raise SynapseError(400, "Missing renewal token")
- renewal_token = request.args[b"token"][0]
+ async def on_GET(self, request: Request) -> None:
+ renewal_token = parse_string(request, "token", required=True)
(
token_valid,
token_stale,
expiration_ts,
- ) = await self.account_activity_handler.renew_account(
- renewal_token.decode("utf8")
- )
+ ) = await self.account_activity_handler.renew_account(renewal_token)
if token_valid:
status_code = 200
@@ -77,11 +76,7 @@ async def on_GET(self, request):
class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/send_mail$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -91,7 +86,7 @@ def __init__(self, hs):
hs.config.account_validity.account_validity_renew_by_email_enabled
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id)
@@ -99,6 +94,6 @@ async def on_POST(self, request):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountValidityRenewServlet(hs).register(http_server)
AccountValiditySendMailServlet(hs).register(http_server)
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 6ea1b50a625e..df8cc4ac7ab5 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -15,11 +15,14 @@
import logging
from typing import TYPE_CHECKING
+from twisted.web.server import Request
+
from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError
+from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.http.server import respond_with_html
+from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
from ._base import client_patterns
@@ -46,9 +49,10 @@ def __init__(self, hs: "HomeServer"):
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
+ self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
- async def on_GET(self, request, stagetype):
+ async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@@ -74,6 +78,12 @@ async def on_GET(self, request, stagetype):
# re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session)
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ )
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -81,7 +91,7 @@ async def on_GET(self, request, stagetype):
respond_with_html(request, 200, html)
return None
- async def on_POST(self, request, stagetype):
+ async def on_POST(self, request: Request, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
@@ -95,29 +105,32 @@ async def on_POST(self, request, stagetype):
authdict = {"response": response, "session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.RECAPTCHA, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.recaptcha_template.render(
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.recaptcha_public_key,
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.TERMS, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.TERMS, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.terms_template.render(
session=session,
terms_url="%s_matrix/consent?v=%s"
@@ -127,10 +140,33 @@ async def on_POST(self, request, stagetype):
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
+
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ token = parse_string(request, "token", required=True)
+ authdict = {"session": session, "token": token}
+
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ error=e.msg,
+ )
+ else:
+ html = self.success_template.render()
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -139,5 +175,5 @@ async def on_POST(self, request, stagetype):
return None
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AuthRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 88e3aac7974b..65b3b5ce2cc5 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
@@ -61,8 +62,19 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"org.matrix.msc3244.room_capabilities"
] = MSC3244_CAPABILITIES
+ if self.config.experimental.msc3283_enabled:
+ response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
+ "enabled": self.config.enable_set_displayname
+ }
+ response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
+ "enabled": self.config.enable_set_avatar_url
+ }
+ response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
+ "enabled": self.config.enable_3pid_changes
+ }
+
return 200, response
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
CapabilitiesRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8b9674db064a..25bc3c8f477b 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -14,34 +14,36 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api import errors
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
@@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/delete_devices")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -65,7 +67,7 @@ def __init__(self, hs):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
@@ -100,18 +102,16 @@ async def on_POST(self, request):
class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P[^/]*)$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
- async def on_GET(self, request, device_id):
+ async def on_GET(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
@@ -119,7 +119,9 @@ async def on_GET(self, request, device_id):
return 200, device
@interactive_auth_handler
- async def on_DELETE(self, request, device_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
@@ -146,7 +148,9 @@ async def on_DELETE(self, request, device_id):
await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {}
- async def on_PUT(self, request, device_id):
+ async def on_PUT(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request)
@@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet):
PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_GET(self, request: SynapseRequest):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
dehydrated_device = await self.device_handler.get_dehydrated_device(
requester.user.to_string()
@@ -211,7 +215,7 @@ async def on_GET(self, request: SynapseRequest):
else:
raise errors.NotFoundError("No dehydrated device available")
- async def on_PUT(self, request: SynapseRequest):
+ async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_json_object_from_request(request)
requester = await self.auth.get_user_by_req(request)
@@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet):
"/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_POST(self, request: SynapseRequest):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
submission = parse_json_object_from_request(request)
@@ -292,7 +296,7 @@ async def on_POST(self, request: SynapseRequest):
return (200, result)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index ffa075c8e5f6..ee247e3d1e0d 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
from synapse.api.errors import (
AuthError,
@@ -22,14 +24,19 @@
NotFoundError,
SynapseError,
)
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server)
@@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_alias):
- room_alias = RoomAlias.from_string(room_alias)
+ async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
- res = await self.directory_handler.get_association(room_alias)
+ res = await self.directory_handler.get_association(room_alias_obj)
return 200, res
- async def on_PUT(self, request, room_alias):
- room_alias = RoomAlias.from_string(room_alias)
+ async def on_PUT(
+ self, request: SynapseRequest, room_alias: str
+ ) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
if "room_id" not in content:
@@ -61,7 +70,7 @@ async def on_PUT(self, request, room_alias):
)
logger.debug("Got content: %s", content)
- logger.debug("Got room name: %s", room_alias.to_string())
+ logger.debug("Got room name: %s", room_alias_obj.to_string())
room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None
@@ -78,22 +87,25 @@ async def on_PUT(self, request, room_alias):
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.create_association(
- requester, room_alias, room_id, servers
+ requester, room_alias_obj, room_id, servers
)
return 200, {}
- async def on_DELETE(self, request, room_alias):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_alias: str
+ ) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
+
try:
service = self.auth.get_appservice_by_req(request)
- room_alias = RoomAlias.from_string(room_alias)
await self.directory_handler.delete_appservice_association(
- service, room_alias
+ service, room_alias_obj
)
logger.info(
"Application service at %s deleted alias %s",
service.url,
- room_alias.to_string(),
+ room_alias_obj.to_string(),
)
return 200, {}
except InvalidClientCredentialsError:
@@ -103,12 +115,10 @@ async def on_DELETE(self, request, room_alias):
requester = await self.auth.get_user_by_req(request)
user = requester.user
- room_alias = RoomAlias.from_string(room_alias)
-
- await self.directory_handler.delete_association(requester, room_alias)
+ await self.directory_handler.delete_association(requester, room_alias_obj)
logger.info(
- "User %s deleted alias %s", user.to_string(), room_alias.to_string()
+ "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
)
return 200, {}
@@ -117,20 +127,22 @@ async def on_DELETE(self, request, room_alias):
class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
- async def on_PUT(self, request, room_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -142,7 +154,9 @@ async def on_PUT(self, request, room_id):
return 200, {}
- async def on_DELETE(self, request, room_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.edit_published_room_list(
@@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
"/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- def on_PUT(self, request, network_id, room_id):
+ async def on_PUT(
+ self, request: SynapseRequest, network_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- return self._edit(request, network_id, room_id, visibility)
+ return await self._edit(request, network_id, room_id, visibility)
- def on_DELETE(self, request, network_id, room_id):
- return self._edit(request, network_id, room_id, "private")
+ async def on_DELETE(
+ self, request: SynapseRequest, network_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ return await self._edit(request, network_id, room_id, "private")
- async def _edit(self, request, network_id, room_id, visibility):
+ async def _edit(
+ self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 52bb579cfd40..13b72a045a4a 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -14,11 +14,18 @@
"""This module contains REST servlets to do with event streaming, /events."""
import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.http.server import HttpServer
+from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
- room_id = None
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
if is_guest:
- if b"room_id" not in request.args:
+ if b"room_id" not in args:
raise SynapseError(400, "Guest users must specify room_id param")
- if b"room_id" in request.args:
- room_id = request.args[b"room_id"][0].decode("ascii")
+ room_id = parse_string(request, "room_id")
pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if b"timeout" in request.args:
+ if b"timeout" in args:
try:
- timeout = int(request.args[b"timeout"][0])
+ timeout = int(args[b"timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
- as_client_event = b"raw" not in request.args
+ as_client_event = b"raw" not in args
chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
@@ -70,25 +76,27 @@ async def on_GET(self, request):
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()
- async def on_GET(self, request, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, event_id: str
+ ) -> Tuple[int, Union[str, JsonDict]]:
requester = await self.auth.get_user_by_req(request)
event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec()
if event:
- event = await self._event_serializer.serialize_event(event, time_now)
- return 200, event
+ result = await self._event_serializer.serialize_event(event, time_now)
+ return 200, result
else:
return 404, "Event not found."
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EventStreamRestServlet(hs).register(http_server)
EventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 411667a9c8d9..6ed60c74181f 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -13,26 +13,34 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID
from ._base import client_patterns, set_timeline_upper_limit
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- async def on_GET(self, request, user_id, filter_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, filter_id: str
+ ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@@ -43,13 +51,13 @@ async def on_GET(self, request, user_id, filter_id):
raise AuthError(403, "Can only get filters for local users")
try:
- filter_id = int(filter_id)
+ filter_id_int = int(filter_id)
except Exception:
raise SynapseError(400, "Invalid filter_id")
try:
filter_collection = await self.filtering.get_user_filter(
- user_localpart=target_user.localpart, filter_id=filter_id
+ user_localpart=target_user.localpart, filter_id=filter_id_int
)
except StoreError as e:
if e.code != 404:
@@ -62,13 +70,15 @@ async def on_GET(self, request, user_id, filter_id):
class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P[^/]*)/filter")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@@ -89,6 +99,6 @@ async def on_POST(self, request, user_id):
return 200, {"filter_id": str(filter_id)}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index 6285680c00c8..c3667ff8aaad 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -26,6 +26,7 @@
)
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -930,7 +931,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, result
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py
index 12ba0e91dbd1..49b1037b28d4 100644
--- a/synapse/rest/client/initial_sync.py
+++ b/synapse/rest/client/initial_sync.py
@@ -12,25 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Dict, List, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
# TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- as_client_event = b"raw" not in request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
+ as_client_event = b"raw" not in args
pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
@@ -43,5 +51,5 @@ async def on_GET(self, request):
return 200, content
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
InitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index d0d9d30d40a3..7281b2ee2912 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,19 +15,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Optional, Tuple
-from synapse.api.errors import SynapseError
+from synapse.api.errors import InvalidAPICallError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.types import StreamToken
+from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -59,18 +65,16 @@ class KeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys")
- async def on_POST(self, request, device_id):
+ async def on_POST(
+ self, request: SynapseRequest, device_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -148,21 +152,30 @@ class KeyQueryServlet(RestServlet):
PATTERNS = client_patterns("/keys/query$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
+
+ device_keys = body.get("device_keys")
+ if not isinstance(device_keys, dict):
+ raise InvalidAPICallError("'device_keys' must be a JSON object")
+
+ def is_list_of_strings(values: Any) -> bool:
+ return isinstance(values, list) and all(isinstance(v, str) for v in values)
+
+ if any(not is_list_of_strings(keys) for keys in device_keys.values()):
+ raise InvalidAPICallError(
+ "'device_keys' values must be a list of strings",
+ )
+
result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
@@ -181,17 +194,13 @@ class KeyChangesServlet(RestServlet):
PATTERNS = client_patterns("/keys/changes$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from", required=True)
@@ -231,12 +240,12 @@ class OneTimeKeyServlet(RestServlet):
PATTERNS = client_patterns("/keys/claim$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
@@ -255,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -267,7 +272,7 @@ def __init__(self, hs):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -315,16 +320,12 @@ class SignaturesUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/signatures/upload$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -335,7 +336,7 @@ async def on_POST(self, request):
return 200, result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index 7d1bc40658a6..68fb08d0ba0f 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -19,6 +19,7 @@
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -103,5 +104,5 @@ def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KnockRoomAliasServlet(hs).register(http_server)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 0c8d8967b7ee..4be502a77b04 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
@@ -104,7 +104,13 @@ def __init__(self, hs: "HomeServer"):
burst_count=self.hs.config.rc_login_account.burst_count,
)
- def on_GET(self, request: SynapseRequest):
+ # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
+ # The reason for this is to ensure that the auth_provider_ids are registered
+ # with SsoHandler, which in turn ensures that the login/registration prometheus
+ # counters are initialised for the auth_provider_ids.
+ _load_sso_handlers(hs)
+
+ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -151,7 +157,7 @@ def on_GET(self, request: SynapseRequest):
return 200, {"flows": flows}
- async def on_POST(self, request: SynapseRequest):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled:
@@ -211,7 +217,7 @@ async def _do_appservice_login(
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
- ):
+ ) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -461,10 +467,7 @@ def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self.access_token_lifetime = hs.config.access_token_lifetime
- async def on_POST(
- self,
- request: SynapseRequest,
- ):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
assert_params_in_dict(refresh_submission, ["refresh_token"])
@@ -499,12 +502,7 @@ class SsoRedirectServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
# register themselves with the main SSOHandler.
- if hs.config.cas_enabled:
- hs.get_cas_handler()
- if hs.config.saml2_enabled:
- hs.get_saml_handler()
- if hs.config.oidc_enabled:
- hs.get_oidc_handler()
+ _load_sso_handlers(hs)
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
@@ -569,7 +567,7 @@ async def on_GET(
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._cas_handler = hs.get_cas_handler()
@@ -591,10 +589,26 @@ async def on_GET(self, request: SynapseRequest) -> None:
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasTicketServlet(hs).register(http_server)
+
+
+def _load_sso_handlers(hs: "HomeServer") -> None:
+ """Ensure that the SSO handlers are loaded, if they are enabled by configuration.
+
+ This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
+ with the main SsoHandler.
+
+ It's safe to call this multiple times.
+ """
+ if hs.config.cas.cas_enabled:
+ hs.get_cas_handler()
+ if hs.config.saml2.saml2_enabled:
+ hs.get_saml_handler()
+ if hs.config.oidc.oidc_enabled:
+ hs.get_oidc_handler()
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 6055cac2bd0a..193a6951b91b 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -13,9 +13,16 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -23,13 +30,13 @@
class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None:
@@ -48,13 +55,13 @@ async def on_POST(self, request):
class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
@@ -67,6 +74,6 @@ async def on_POST(self, request):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LogoutRestServlet(hs).register(http_server)
LogoutAllRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 0ede643c2d91..d1d8a984c630 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -13,26 +13,33 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.events.utils import format_event_for_client_v2_without_room_id
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_patterns("/notifications$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -87,5 +94,5 @@ async def on_GET(self, request):
return 200, {"notifications": returned_push_actions, "next_token": next_token}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
NotificationsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py
index e8d2673819cb..4dda6dce4ba1 100644
--- a/synapse/rest/client/openid.py
+++ b/synapse/rest/client/openid.py
@@ -12,15 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet):
EXPIRES_MS = 3600 * 1000
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
@@ -90,5 +98,5 @@ async def on_POST(self, request, user_id):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
IdTokenServlet(hs).register(http_server)
diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py
index a83927aee641..6d64efb1651d 100644
--- a/synapse/rest/client/password_policy.py
+++ b/synapse/rest/client/password_policy.py
@@ -13,28 +13,32 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+from twisted.web.server import Request
+
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class PasswordPolicyServlet(RestServlet):
PATTERNS = client_patterns("/password_policy$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
- def on_GET(self, request):
+ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.enabled or not self.policy:
return (200, {})
@@ -53,5 +57,5 @@ def on_GET(self, request):
return (200, policy)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PasswordPolicyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py
index 6c27e5faf986..94dd4fe2f45c 100644
--- a/synapse/rest/client/presence.py
+++ b/synapse/rest/client/presence.py
@@ -15,12 +15,18 @@
""" This module contains REST servlets to do with presence: /presence/
"""
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -28,7 +34,7 @@
class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler()
@@ -37,7 +43,9 @@ def __init__(self, hs):
self._use_presence = hs.config.server.use_presence
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
@@ -53,13 +61,15 @@ async def on_GET(self, request, user_id):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(
+ result = format_user_presence_state(
state, self.clock.time_msec(), include_user_id=False
)
- return 200, state
+ return 200, result
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
@@ -91,5 +101,5 @@ async def on_PUT(self, request, user_id):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PresenceStatusRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index 5463ed2c4f85..d0f20de569c0 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -14,22 +14,31 @@
""" This module contains REST servlets to do with profile: /profile/ """
+from typing import TYPE_CHECKING, Tuple
+
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -48,7 +57,9 @@ async def on_GET(self, request, user_id):
return 200, ret
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
@@ -72,13 +83,15 @@ async def on_PUT(self, request, user_id):
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -97,7 +110,9 @@ async def on_GET(self, request, user_id):
return 200, ret
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
@@ -120,13 +135,15 @@ async def on_PUT(self, request, user_id):
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -149,7 +166,7 @@ async def on_GET(self, request, user_id):
return 200, ret
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 84619c5e4184..98604a93887f 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -13,17 +13,23 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.http.server import respond_with_html_bytes
+from synapse.http.server import HttpServer, respond_with_html_bytes
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,12 +37,12 @@
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@@ -50,14 +56,14 @@ async def on_GET(self, request):
class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"You have been unsubscribed"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
@@ -165,7 +171,7 @@ async def on_GET(self, request):
return None
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 027f8b81fa93..43c04fac6fdb 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -13,27 +13,36 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.presence_handler.bump_presence_active_time(requester.user)
@@ -70,5 +79,5 @@ async def on_POST(self, request, room_id):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReadMarkerRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index b0e672c9bed7..bcb04f0efe04 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import hmac
import logging
import random
from typing import List, Union
@@ -28,6 +27,7 @@
ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
@@ -59,18 +59,6 @@
from ._base import client_patterns, interactive_auth_handler
-# We ought to be using hmac.compare_digest() but on older pythons it doesn't
-# exist. It's a _really minor_ security flaw to use plain string comparison
-# because the timing attack is so obscured by all the other code here it's
-# unlikely to make much difference
-if hasattr(hmac, "compare_digest"):
- compare_digest = hmac.compare_digest
-else:
-
- def compare_digest(a, b):
- return a == b
-
-
logger = logging.getLogger(__name__)
@@ -374,6 +362,55 @@ async def on_GET(self, request):
return 200, {"available": True}
+class RegistrationTokenValidityRestServlet(RestServlet):
+ """Check the validity of a registration token.
+
+ Example:
+
+ GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+ 200 OK
+
+ {
+ "valid": true
+ }
+ """
+
+ PATTERNS = client_patterns(
+ f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+ releases=(),
+ unstable=True,
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.ratelimiter = Ratelimiter(
+ store=self.store,
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+ burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+ )
+
+ async def on_GET(self, request):
+ await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+ if not self.hs.config.enable_registration:
+ raise SynapseError(
+ 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+ )
+
+ token = parse_string(request, "token", required=True)
+ valid = await self.store.registration_token_is_valid(token)
+
+ return 200, {"valid": valid}
+
+
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
@@ -681,6 +718,22 @@ async def on_POST(self, request):
)
if registered:
+ # Check if a token was used to authenticate registration
+ registration_token = await self.auth_handler.get_session_data(
+ session_id,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ )
+ if registration_token:
+ # Increment the `completed` counter for the token
+ await self.store.use_registration_token(registration_token)
+ # Indicate that the token has been successfully used so that
+ # pending is not decremented again when expiring old UIA sessions.
+ await self.store.mark_ui_auth_stage_complete(
+ session_id,
+ LoginType.REGISTRATION_TOKEN,
+ True,
+ )
+
await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
@@ -863,6 +916,11 @@ def _calculate_registration_flows(
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
+ # Prepend registration token to all flows if we're requiring a token
+ if config.registration_requires_token:
+ for flow in flows:
+ flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
return flows
@@ -871,4 +929,5 @@ def register_servlets(hs, http_server):
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server)
+ RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 6d1b083acb47..6a7792e18b2e 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -13,18 +13,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util import stringutils
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet):
}
Creates a new room and shuts down the old one. Returns the ID of the new room.
-
- Args:
- hs (synapse.server.HomeServer):
"""
PATTERNS = client_patterns(
@@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet):
"/rooms/(?P[^/]*)/upgrade$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -84,5 +90,5 @@ async def on_POST(self, request, room_id):
return 200, ret
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomUpgradeRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py
index d2e7f04b406c..1d90493eb082 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/shared_rooms.py
@@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
-from synapse.types import UserID
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet):
releases=(), # This is an unstable feature
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
if not self.user_directory_active:
raise SynapseError(
@@ -63,5 +71,5 @@ async def on_GET(self, request, user_id):
return 200, {"joined": list(rooms)}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e18f4d01b375..65c37be3e96c 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,17 +14,26 @@
import itertools
import logging
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.api.presence import UserPresenceState
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import KnockedSyncResult, SyncConfig
+from synapse.handlers.sync import (
+ ArchivedSyncResult,
+ InvitedSyncResult,
+ JoinedSyncResult,
+ KnockedSyncResult,
+ SyncConfig,
+ SyncResult,
+)
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken
@@ -192,6 +201,8 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, {}
time_now = self.clock.time_msec()
+ # We know that the the requester has an access token since appservices
+ # cannot use sync.
response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
)
@@ -199,7 +210,13 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
logger.debug("Event formatting complete")
return 200, response_content
- async def encode_response(self, time_now, sync_result, access_token_id, filter):
+ async def encode_response(
+ self,
+ time_now: int,
+ sync_result: SyncResult,
+ access_token_id: Optional[int],
+ filter: FilterCollection,
+ ) -> JsonDict:
logger.debug("Formatting events in sync response")
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
@@ -234,7 +251,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter):
logger.debug("building sync response dict")
- response: dict = defaultdict(dict)
+ response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.account_data:
@@ -274,6 +291,8 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter):
if archived:
response["rooms"][Membership.LEAVE] = archived
+ # By the time we get here groups is no longer optional.
+ assert sync_result.groups is not None
if sync_result.groups.join:
response["groups"][Membership.JOIN] = sync_result.groups.join
if sync_result.groups.invite:
@@ -284,7 +303,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter):
return response
@staticmethod
- def encode_presence(events, time_now):
+ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
return {
"events": [
{
@@ -299,25 +318,27 @@ def encode_presence(events, time_now):
}
async def encode_joined(
- self, rooms, time_now, token_id, event_fields, event_formatter
- ):
+ self,
+ rooms: List[JoinedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_fields: List[str],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the joined rooms in a sync result
Args:
- rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
- results for rooms this user is joined to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is joined to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_fields(list): List of event fields to include. If empty,
+ event_fields: List of event fields to include. If empty,
all fields will be returned.
- event_formatter (func[dict]): function to convert from federation format
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, dict[str, object]]: the joined rooms list, in our
- response format
+ The joined rooms list, in our response format
"""
joined = {}
for room in rooms:
@@ -332,23 +353,26 @@ async def encode_joined(
return joined
- async def encode_invited(self, rooms, time_now, token_id, event_formatter):
+ async def encode_invited(
+ self,
+ rooms: List[InvitedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the invited rooms in a sync result
Args:
- rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
- sync results for rooms this user is invited to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is invited to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_formatter (func[dict]): function to convert from federation format
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, dict[str, object]]: the invited rooms list, in our
- response format
+ The invited rooms list, in our response format
"""
invited = {}
for room in rooms:
@@ -371,7 +395,7 @@ async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
time_now: int,
- token_id: int,
+ token_id: Optional[int],
event_formatter: Callable[[Dict], Dict],
) -> Dict[str, Dict[str, Any]]:
"""
@@ -422,25 +446,26 @@ async def encode_knocked(
return knocked
async def encode_archived(
- self, rooms, time_now, token_id, event_fields, event_formatter
- ):
+ self,
+ rooms: List[ArchivedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_fields: List[str],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the archived rooms in a sync result
Args:
- rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
- sync results for rooms this user is joined to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is joined to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_fields(list): List of event fields to include. If empty,
+ event_fields: List of event fields to include. If empty,
all fields will be returned.
- event_formatter (func[dict]): function to convert from federation format
- to client format
+ event_formatter: function to convert from federation format to client format
Returns:
- dict[str, dict[str, object]]: The invited rooms list, in our
- response format
+ The archived rooms list, in our response format
"""
joined = {}
for room in rooms:
@@ -456,23 +481,27 @@ async def encode_archived(
return joined
async def encode_room(
- self, room, time_now, token_id, joined, only_fields, event_formatter
- ):
+ self,
+ room: Union[JoinedSyncResult, ArchivedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ joined: bool,
+ only_fields: Optional[List[str]],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Args:
- room (JoinedSyncResult|ArchivedSyncResult): sync result for a
- single room
- time_now (int): current time - used as a baseline for age
- calculations
- token_id (int): ID of the user's auth token - used for namespacing
+ room: sync result for a single room
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- joined (bool): True if the user is joined to this room - will mean
+ joined: True if the user is joined to this room - will mean
we handle ephemeral events
- only_fields(list): Optional. The list of event fields to include.
- event_formatter (func[dict]): function to convert from federation format
+ only_fields: Optional. The list of event fields to include.
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, object]: the room, encoded in our response format
+ The room, encoded in our response format
"""
def serialize(events):
@@ -508,7 +537,7 @@ def serialize(events):
account_data = room.account_data
- result = {
+ result: JsonDict = {
"timeline": {
"events": serialized_timeline,
"prev_batch": await room.timeline.prev_batch.to_string(self.store),
@@ -519,6 +548,7 @@ def serialize(events):
}
if joined:
+ assert isinstance(room, JoinedSyncResult)
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
@@ -528,5 +558,5 @@ def serialize(events):
return result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index c14f83be1878..c88cb9367c5f 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -13,12 +13,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -29,12 +36,14 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
@@ -54,12 +63,14 @@ class TagServlet(RestServlet):
"/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, room_id, tag):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@@ -70,7 +81,9 @@ async def on_PUT(self, request, user_id, room_id, tag):
return 200, {}
- async def on_DELETE(self, request, user_id, room_id, tag):
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@@ -80,6 +93,6 @@ async def on_DELETE(self, request, user_id, room_id, tag):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)
diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py
index b5c67c9bb67e..b895c73acf2c 100644
--- a/synapse/rest/client/thirdparty.py
+++ b/synapse/rest/client/thirdparty.py
@@ -12,27 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.constants import ThirdPartyEntityKind
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols()
@@ -42,13 +48,15 @@ async def on_GET(self, request):
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols(
@@ -63,16 +71,18 @@ async def on_GET(self, request, protocol):
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
- fields = request.args
+ fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@@ -85,16 +95,18 @@ async def on_GET(self, request, protocol):
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
- fields = request.args
+ fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@@ -104,7 +116,7 @@ async def on_GET(self, request, protocol):
return 200, results
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py
index b2f858545cbe..c8c3b25bd36f 100644
--- a/synapse/rest/client/tokenrefresh.py
+++ b/synapse/rest/client/tokenrefresh.py
@@ -12,11 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class TokenRefreshRestServlet(RestServlet):
"""
@@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> None:
raise AuthError(403, "tokenrefresh is no longer supported.")
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index 7e8912f0b919..885281111438 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -13,29 +13,32 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Searches for users in directory
Returns:
@@ -75,5 +78,5 @@ async def on_POST(self, request):
return 200, results
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index fa2e4e9cba48..a1a815cf8256 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -17,9 +17,17 @@
import logging
import re
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
from synapse.api.constants import RoomCreationPreset
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -27,7 +35,7 @@
class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
@@ -45,7 +53,7 @@ def __init__(self, hs):
in self.config.encryption_enabled_by_default_for_room_presets
)
- def on_GET(self, request):
+ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
return (
200,
{
@@ -89,5 +97,5 @@ def on_GET(self, request):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VersionsRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index f53020520d37..9d46ed3af3f5 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -15,20 +15,27 @@
import base64
import hashlib
import hmac
+from typing import TYPE_CHECKING, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests
)
@@ -69,5 +76,5 @@ async def on_GET(self, request):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VoipRestServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a029d426f0b6..12bd745cb21c 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -15,7 +15,7 @@
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from twisted.web.server import Request
@@ -414,9 +414,9 @@ def _select_thumbnail(
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
- crop_info_list = []
+ crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
# Other thumbnails.
- crop_info_list2 = []
+ crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "crop":
@@ -451,15 +451,19 @@ def _select_thumbnail(
info,
)
)
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info dictionary and pick the thumbnail that appears earlier
+ # in the list of candidates.
if crop_info_list:
- thumbnail_info = min(crop_info_list)[-1]
+ thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
elif crop_info_list2:
- thumbnail_info = min(crop_info_list2)[-1]
+ thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
- info_list = []
+ info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = []
# Other thumbnails.
- info_list2 = []
+ info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
@@ -477,10 +481,14 @@ def _select_thumbnail(
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info dictionary and pick the thumbnail that appears earlier
+ # in the list of candidates.
if info_list:
- thumbnail_info = min(info_list)[-1]
+ thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
elif info_list2:
- thumbnail_info = min(info_list2)[-1]
+ thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
if thumbnail_info:
return FileInfo(
diff --git a/synapse/server.py b/synapse/server.py
index de6517663e6b..5adeeff61a5f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -76,6 +76,7 @@
from synapse.handlers.event_auth import EventAuthHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.federation import FederationHandler
+from synapse.handlers.federation_event import FederationEventHandler
from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
from synapse.handlers.identity import IdentityHandler
from synapse.handlers.initial_sync import InitialSyncHandler
@@ -546,6 +547,10 @@ def get_event_stream_handler(self) -> EventStreamHandler:
def get_federation_handler(self) -> FederationHandler:
return FederationHandler(self)
+ @cache_in_self
+ def get_federation_event_handler(self) -> FederationEventHandler:
+ return FederationEventHandler(self)
+
@cache_in_self
def get_identity_handler(self) -> IdentityHandler:
return IdentityHandler(self)
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index f19075b76050..d87a53891740 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -12,26 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.events import EventBase
from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
SERVER_NOTICE_ROOM_TAG = "m.server_notice"
class ServerNoticesManager:
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
-
+ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._config = hs.config
self._account_data_handler = hs.get_account_data_handler()
@@ -58,6 +55,7 @@ async def send_notice(
event_content: dict,
type: str = EventTypes.Message,
state_key: Optional[str] = None,
+ txn_id: Optional[str] = None,
) -> EventBase:
"""Send a notice to the given user
@@ -68,6 +66,7 @@ async def send_notice(
event_content: content of event to send
type: type of event
is_state_event: Is the event a state event
+ txn_id: The transaction ID.
"""
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
@@ -90,7 +89,7 @@ async def send_notice(
event_dict["state_key"] = state_key
event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, ratelimit=False
+ requester, event_dict, ratelimit=False, txn_id=txn_id
)
return event
diff --git a/synapse/static/client/register/style.css b/synapse/static/client/register/style.css
index 5a7b6eebf2f0..8a39b5d0f576 100644
--- a/synapse/static/client/register/style.css
+++ b/synapse/static/client/register/style.css
@@ -57,4 +57,8 @@ textarea, input {
background-color: #f8f8f8;
border: 1px #ccc solid;
-}
\ No newline at end of file
+}
+
+.error {
+ color: red;
+}
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0d0d10a7181d..9819f587e030 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,6 +63,7 @@
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
+from .session import SessionStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
@@ -121,6 +122,7 @@ class DataStore(
ServerMetricsStore,
EventForwardExtremitiesStore,
LockStore,
+ SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 375463e4e979..9501f00f3bb3 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -520,16 +520,26 @@ async def _get_events_from_cache_or_db(
# We now look up if we're already fetching some of the events in the DB,
# if so we wait for those lookups to finish instead of pulling the same
# events out of the DB multiple times.
- already_fetching: Dict[str, defer.Deferred] = {}
+ #
+ # Note: we might get the same `ObservableDeferred` back for multiple
+ # events we're already fetching, so we deduplicate the deferreds to
+ # avoid extraneous work (if we don't do this we can end up in a n^2 mode
+ # when we wait on the same Deferred N times, then try and merge the
+ # same dict into itself N times).
+ already_fetching_ids: Set[str] = set()
+ already_fetching_deferreds: Set[
+ ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ] = set()
for event_id in missing_events_ids:
deferred = self._current_event_fetches.get(event_id)
if deferred is not None:
# We're already pulling the event out of the DB. Add the deferred
# to the collection of deferreds to wait on.
- already_fetching[event_id] = deferred.observe()
+ already_fetching_ids.add(event_id)
+ already_fetching_deferreds.add(deferred)
- missing_events_ids.difference_update(already_fetching)
+ missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids:
log_ctx = current_context()
@@ -569,18 +579,25 @@ async def _get_events_from_cache_or_db(
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
- if already_fetching:
+ if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results
# to ours.
results = await make_deferred_yieldable(
defer.gatherResults(
- already_fetching.values(),
+ (d.observe() for d in already_fetching_deferreds),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
for result in results:
- event_entry_map.update(result)
+ # We filter out events that we haven't asked for as we might get
+ # a *lot* of superfluous events back, and there is no point
+ # going through and inserting them all (which can take time).
+ event_entry_map.update(
+ (event_id, entry)
+ for event_id, entry in result.items()
+ if event_id in already_fetching_ids
+ )
if not allow_rejected:
event_entry_map = {
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 664c65dac5a6..bccff5e5b95c 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -295,6 +295,7 @@ def _purge_history_txn(
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
+ self._invalidate_get_event_cache(event_id)
logger.info("[purge] done")
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b48fe086d4cc..63ac09c61dad 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -48,6 +48,11 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"
self._remove_stale_pushers,
)
+ self.db_pool.updates.register_background_update_handler(
+ "remove_deleted_email_pushers",
+ self._remove_deleted_email_pushers,
+ )
+
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
@@ -388,6 +393,74 @@ def _delete_pushers(txn) -> int:
return number_deleted
+ async def _remove_deleted_email_pushers(
+ self, progress: dict, batch_size: int
+ ) -> int:
+ """A background update that deletes all pushers for deleted email addresses.
+
+ In previous versions of synapse, when users deleted their email address, it didn't
+ also delete all the pushers for that email address. This background update removes
+ those to prevent unwanted emails. This should only need to be run once (when users
+ upgrade to v1.42.0
+
+ Args:
+ progress: dict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ last_pusher = progress.get("last_pusher", 0)
+
+ def _delete_pushers(txn) -> int:
+
+ sql = """
+ SELECT p.id, p.user_name, p.app_id, p.pushkey
+ FROM pushers AS p
+ LEFT JOIN user_threepids AS t
+ ON t.user_id = p.user_name
+ AND t.medium = 'email'
+ AND t.address = p.pushkey
+ WHERE t.user_id is NULL
+ AND p.app_id = 'm.email'
+ AND p.id > ?
+ ORDER BY p.id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_pusher, batch_size))
+ rows = txn.fetchall()
+
+ last = None
+ num_deleted = 0
+ for row in rows:
+ last = row[0]
+ num_deleted += 1
+ self.db_pool.simple_delete_txn(
+ txn,
+ "pushers",
+ {"user_name": row[1], "app_id": row[2], "pushkey": row[3]},
+ )
+
+ if last is not None:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "remove_deleted_email_pushers", {"last_pusher": last}
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deleted_email_pushers", _delete_pushers
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "remove_deleted_email_pushers"
+ )
+
+ return number_deleted
+
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c67bea81c6b9..a6517962f619 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -754,16 +754,18 @@ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[s
)
return user_id
- def get_user_id_by_threepid_txn(self, txn, medium, address):
+ def get_user_id_by_threepid_txn(
+ self, txn, medium: str, address: str
+ ) -> Optional[str]:
"""Returns user id from threepid
Args:
txn (cursor):
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
+ medium: threepid medium e.g. email
+ address: threepid address e.g. me@example.com
Returns:
- str|None: user id or None if no user id/threepid mapping exists
+ user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
txn,
@@ -776,14 +778,21 @@ def get_user_id_by_threepid_txn(self, txn, medium, address):
return ret["user_id"]
return None
- async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ async def user_add_threepid(
+ self,
+ user_id: str,
+ medium: str,
+ address: str,
+ validated_at: int,
+ added_at: int,
+ ) -> None:
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id):
+ async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@@ -791,7 +800,9 @@ async def user_get_threepids(self, user_id):
"user_get_threepids",
)
- async def user_delete_threepid(self, user_id, medium, address) -> None:
+ async def user_delete_threepid(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
@@ -1157,6 +1168,322 @@ async def update_access_token_last_validated(self, token_id: int) -> None:
desc="update_access_token_last_validated",
)
+ async def registration_token_is_valid(self, token: str) -> bool:
+ """Checks if a token can be used to authenticate a registration.
+
+ Args:
+ token: The registration token to be checked
+ Returns:
+ True if the token is valid, False otherwise.
+ """
+ res = await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ )
+
+ # Check if the token exists
+ if res is None:
+ return False
+
+ # Check if the token has expired
+ now = self._clock.time_msec()
+ if res["expiry_time"] and res["expiry_time"] < now:
+ return False
+
+ # Check if the token has been used up
+ if (
+ res["uses_allowed"]
+ and res["pending"] + res["completed"] >= res["uses_allowed"]
+ ):
+ return False
+
+ # Otherwise, the token is valid
+ return True
+
+ async def set_registration_token_pending(self, token: str) -> None:
+ """Increment the pending registrations counter for a token.
+
+ Args:
+ token: The registration token pending use
+ """
+
+ def _set_registration_token_pending_txn(txn):
+ pending = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": pending + 1},
+ )
+
+ return await self.db_pool.runInteraction(
+ "set_registration_token_pending", _set_registration_token_pending_txn
+ )
+
+ async def use_registration_token(self, token: str) -> None:
+ """Complete a use of the given registration token.
+
+ The `pending` counter will be decremented, and the `completed`
+ counter will be incremented.
+
+ Args:
+ token: The registration token to be 'used'
+ """
+
+ def _use_registration_token_txn(txn):
+ # Normally, res is Optional[Dict[str, Any]].
+ # Override type because the return type is only optional if
+ # allow_none is True, and we don't want mypy throwing errors
+ # about None not being indexable.
+ res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ ) # type: ignore
+
+ # Decrement pending and increment completed
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={
+ "completed": res["completed"] + 1,
+ "pending": res["pending"] - 1,
+ },
+ )
+
+ return await self.db_pool.runInteraction(
+ "use_registration_token", _use_registration_token_txn
+ )
+
+ async def get_registration_tokens(
+ self, valid: Optional[bool] = None
+ ) -> List[Dict[str, Any]]:
+ """List all registration tokens. Used by the admin API.
+
+ Args:
+ valid: If True, only valid tokens are returned.
+ If False, only invalid tokens are returned.
+ Default is None: return all tokens regardless of validity.
+
+ Returns:
+ A list of dicts, each containing details of a token.
+ """
+
+ def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+ if valid is None:
+ # Return all tokens regardless of validity
+ txn.execute("SELECT * FROM registration_tokens")
+
+ elif valid:
+ # Select valid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
+ "AND (expiry_time > ? OR expiry_time IS NULL)"
+ )
+ txn.execute(sql, [now])
+
+ else:
+ # Select invalid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "uses_allowed <= pending + completed OR expiry_time <= ?"
+ )
+ txn.execute(sql, [now])
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "select_registration_tokens",
+ select_registration_tokens_txn,
+ self._clock.time_msec(),
+ valid,
+ )
+
+ async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
+ """Get info about the given registration token. Used by the admin API.
+
+ Args:
+ token: The token to retrieve information about.
+
+ Returns:
+ A dict, or None if token doesn't exist.
+ """
+ return await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ desc="get_one_registration_token",
+ )
+
+ async def generate_registration_token(
+ self, length: int, chars: str
+ ) -> Optional[str]:
+ """Generate a random registration token. Used by the admin API.
+
+ Args:
+ length: The length of the token to generate.
+ chars: A string of the characters allowed in the generated token.
+
+ Returns:
+ The generated token.
+
+ Raises:
+ SynapseError if a unique registration token could still not be
+ generated after a few tries.
+ """
+ # Make a few attempts at generating a unique token of the required
+ # length before failing.
+ for _i in range(3):
+ # Generate token
+ token = "".join(random.choices(chars, k=length))
+
+ # Check if the token already exists
+ existing_token = await self.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="token",
+ allow_none=True,
+ desc="check_if_registration_token_exists",
+ )
+
+ if existing_token is None:
+ # The generated token doesn't exist yet, return it
+ return token
+
+ raise SynapseError(
+ 500,
+ "Unable to generate a unique registration token. Try again with a greater length",
+ Codes.UNKNOWN,
+ )
+
+ async def create_registration_token(
+ self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
+ ) -> bool:
+ """Create a new registration token. Used by the admin API.
+
+ Args:
+ token: The token to create.
+ uses_allowed: The number of times the token can be used to complete
+ a registration before it becomes invalid. A value of None indicates
+ unlimited uses.
+ expiry_time: The latest time the token is valid. Given as the
+ number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
+ None indicates that the token does not expire.
+
+ Returns:
+ Whether the row was inserted or not.
+ """
+
+ def _create_registration_token_txn(txn):
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token"],
+ allow_none=True,
+ )
+
+ if row is not None:
+ # Token already exists
+ return False
+
+ self.db_pool.simple_insert_txn(
+ txn,
+ "registration_tokens",
+ values={
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ },
+ )
+
+ return True
+
+ return await self.db_pool.runInteraction(
+ "create_registration_token", _create_registration_token_txn
+ )
+
+ async def update_registration_token(
+ self, token: str, updatevalues: Dict[str, Optional[int]]
+ ) -> Optional[Dict[str, Any]]:
+ """Update a registration token. Used by the admin API.
+
+ Args:
+ token: The token to update.
+ updatevalues: A dict with the fields to update. E.g.:
+ `{"uses_allowed": 3}` to update just uses_allowed, or
+ `{"uses_allowed": 3, "expiry_time": None}` to update both.
+ This is passed straight to simple_update_one.
+
+ Returns:
+ A dict with all info about the token, or None if token doesn't exist.
+ """
+
+ def _update_registration_token_txn(txn):
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues=updatevalues,
+ )
+ except StoreError:
+ # Update failed because token does not exist
+ return None
+
+ # Get all info about the token so it can be sent in the response
+ return self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=[
+ "token",
+ "uses_allowed",
+ "pending",
+ "completed",
+ "expiry_time",
+ ],
+ allow_none=True,
+ )
+
+ return await self.db_pool.runInteraction(
+ "update_registration_token", _update_registration_token_txn
+ )
+
+ async def delete_registration_token(self, token: str) -> bool:
+ """Delete a registration token. Used by the admin API.
+
+ Args:
+ token: The token to delete.
+
+ Returns:
+ Whether the token was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ desc="delete_registration_token",
+ )
+ except StoreError:
+ # Deletion failed because token does not exist
+ return False
+
+ return True
+
@cached()
async def mark_access_token_as_used(self, token_id: int) -> None:
"""
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e8157ba3d4eb..c58a4b869072 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -307,7 +307,9 @@ def _get_room_summary_txn(txn):
)
@cached()
- async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
+ async def get_invited_rooms_for_local_user(
+ self, user_id: str
+ ) -> List[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@@ -384,9 +386,10 @@ def _get_rooms_for_local_user_where_membership_is_txn(
)
sql = """
- SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
FROM local_current_membership AS c
INNER JOIN events AS e USING (room_id, event_id)
+ INNER JOIN rooms AS r USING (room_id)
WHERE
user_id = ?
AND %s
@@ -395,7 +398,7 @@ def _get_rooms_for_local_user_where_membership_is_txn(
)
txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
+ results = [RoomsForUser(*r) for r in txn]
return results
@@ -445,7 +448,8 @@ async def get_rooms_for_user_with_stream_ordering(
Returns:
Returns the rooms the user is in currently, along with the stream
- ordering of the most recent join for that user and room.
+ ordering of the most recent join for that user and room, along with
+ the room version of the room.
"""
return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
@@ -522,7 +526,9 @@ def _get_users_server_still_shares_room_with_txn(txn):
_get_users_server_still_shares_room_with_txn,
)
- async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
+ async def get_rooms_for_user(
+ self, user_id: str, on_invalidate=None
+ ) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
new file mode 100644
index 000000000000..172f27d109ad
--- /dev/null
+++ b/synapse/storage/databases/main/session.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class SessionStore(SQLBaseStore):
+ """
+ A store for generic session data.
+
+ Each type of session should provide a unique type (to separate sessions).
+
+ Sessions are automatically removed when they expire.
+ """
+
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ # Create a background job for culling expired sessions.
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
+
+ async def create_session(
+ self, session_type: str, value: JsonDict, expiry_ms: int
+ ) -> str:
+ """
+ Creates a new pagination session for the room hierarchy endpoint.
+
+ Args:
+ session_type: The type for this session.
+ value: The value to store.
+ expiry_ms: How long before an item is evicted from the cache
+ in milliseconds. Default is 0, indicating items never get
+ evicted based on time.
+
+ Returns:
+ The newly created session ID.
+
+ Raises:
+ StoreError if a unique session ID cannot be generated.
+ """
+ # autogen a session ID and try to create it. We may clash, so just
+ # try a few times till one goes through, giving up eventually.
+ attempts = 0
+ while attempts < 5:
+ session_id = stringutils.random_string(24)
+
+ try:
+ await self.db_pool.simple_insert(
+ table="sessions",
+ values={
+ "session_id": session_id,
+ "session_type": session_type,
+ "value": json_encoder.encode(value),
+ "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
+ },
+ desc="create_session",
+ )
+
+ return session_id
+ except self.db_pool.engine.module.IntegrityError:
+ attempts += 1
+ raise StoreError(500, "Couldn't generate a session ID.")
+
+ async def get_session(self, session_type: str, session_id: str) -> JsonDict:
+ """
+ Retrieve data stored with create_session
+
+ Args:
+ session_type: The type for this session.
+ session_id: The session ID returned from create_session.
+
+ Raises:
+ StoreError if the session cannot be found.
+ """
+
+ def _get_session(
+ txn: LoggingTransaction, session_type: str, session_id: str, ts: int
+ ) -> JsonDict:
+ # This includes the expiry time since items are only periodically
+ # deleted, not upon expiry.
+ select_sql = """
+ SELECT value FROM sessions WHERE
+ session_type = ? AND session_id = ? AND expiry_time_ms > ?
+ """
+ txn.execute(select_sql, [session_type, session_id, ts])
+ row = txn.fetchone()
+
+ if not row:
+ raise StoreError(404, "No session")
+
+ return db_to_json(row[0])
+
+ return await self.db_pool.runInteraction(
+ "get_session",
+ _get_session,
+ session_type,
+ session_id,
+ self._clock.time_msec(),
+ )
+
+ @wrap_as_background_process("delete_expired_sessions")
+ async def _delete_expired_sessions(self) -> None:
+ """Remove sessions with expiry dates that have passed."""
+
+ def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
+ sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "delete_expired_sessions",
+ _delete_expired_sessions_txn,
+ self._clock.time_msec(),
+ )
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad3a..4d6bbc94c774 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@
import attr
+from synapse.api.constants import LoginType
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ def _delete_old_ui_auth_sessions_txn(
keyvalues={},
)
+ # If a registration token was used, decrement the pending counter
+ # before deleting the session.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ )
+
+ # Get the tokens used and how much pending needs to be decremented by.
+ token_counts: Dict[str, int] = {}
+ for r in rows:
+ # If registration was successfully completed, the result of the
+ # registration token stage for that session will be True.
+ # If a token was used to authenticate, but registration was
+ # never completed, the result will be the token used.
+ token = db_to_json(r["result"])
+ if isinstance(token, str):
+ token_counts[token] = token_counts.get(token, 0) + 1
+
+ # Update the `pending` counters.
+ if len(token_counts) > 0:
+ token_rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ )
+ for token_row in token_rows:
+ token = token_row["token"]
+ new_pending = token_row["pending"] - token_counts[token]
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": new_pending},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 1101c0a58db8..ec0c72473024 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -365,7 +365,7 @@ async def is_room_world_readable_or_publicly_joinable(self, room_id):
return False
async def update_profile_in_user_dir(
- self, user_id: str, display_name: str, avatar_url: str
+ self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
) -> None:
"""
Update or add a user's profile in the user directory.
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c34fbf21bc42..2500381b7b85 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -14,25 +14,41 @@
# limitations under the License.
import logging
-from collections import namedtuple
+from typing import List, Optional, Tuple
+
+import attr
+
+from synapse.types import PersistedEventPosition
logger = logging.getLogger(__name__)
-RoomsForUser = namedtuple(
- "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class RoomsForUser:
+ room_id: str
+ sender: str
+ membership: str
+ event_id: str
+ stream_ordering: int
+ room_version_id: str
+
+
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class GetRoomsForUserWithStreamOrdering:
+ room_id: str
+ event_pos: PersistedEventPosition
-GetRoomsForUserWithStreamOrdering = namedtuple(
- "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
-)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class ProfileInfo:
+ avatar_url: Optional[str]
+ display_name: Optional[str]
-# We store this using a namedtuple so that we save about 3x space over using a
-# dict.
-ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
-# "members" points to a truncated list of (user_id, event_id) tuples for users of
-# a given membership type, suitable for use in calculating heroes for a room.
-# "count" points to the total numberr of users of a given membership type.
-MemberSummary = namedtuple("MemberSummary", ("members", "count"))
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class MemberSummary:
+ # A truncated list of (user_id, event_id) tuples for users of a given
+ # membership type, suitable for use in calculating heroes for a room.
+ members: List[Tuple[str, str]]
+ # The total number of users of a given membership type.
+ count: int
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index a5bc0ee8a560..af9cc69949c3 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# When updating these values, please leave a short summary of the changes below.
+
SCHEMA_VERSION = 63
"""Represents the expectations made by the codebase about the database schema
diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
new file mode 100644
index 000000000000..ee6cf958f4f3
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
@@ -0,0 +1,23 @@
+/* Copyright 2021 Callum Brown
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS registration_tokens(
+ token TEXT NOT NULL, -- The token that can be used for authentication.
+ uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
+ pending INT NOT NULL, -- The number of in progress registrations using this token.
+ completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
+ expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
+ UNIQUE (token)
+);
diff --git a/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
new file mode 100644
index 000000000000..611c4b95cf15
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We may not have deleted all pushers for emails that are no longer linked
+-- to an account, so we set up a background job to delete them.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6302, 'remove_deleted_email_pushers', '{}');
diff --git a/synapse/storage/schema/main/delta/63/03session_store.sql b/synapse/storage/schema/main/delta/63/03session_store.sql
new file mode 100644
index 000000000000..535fb34c109c
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/03session_store.sql
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2021 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS sessions(
+ session_type TEXT NOT NULL, -- The unique key for this type of session.
+ session_id TEXT NOT NULL, -- The session ID passed to the client.
+ value TEXT NOT NULL, -- A JSON dictionary to persist.
+ expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds).
+ UNIQUE (session_type, session_id)
+);
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 6b87f571b8dc..3b3866bff8fb 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -17,7 +17,7 @@
import attr
from synapse.api.constants import EduTypes
-from synapse.events.presence_router import PresenceRouter
+from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
@@ -34,7 +34,7 @@ class PresenceRouterTestConfig:
users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
-class PresenceRouterTestModule:
+class LegacyPresenceRouterTestModule:
def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
self._config = config
self._module_api = module_api
@@ -77,6 +77,53 @@ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
return config
+class PresenceRouterTestModule:
+ def __init__(self, config: PresenceRouterTestConfig, api: ModuleApi):
+ self._config = config
+ self._module_api = api
+ api.register_presence_router_callbacks(
+ get_users_for_states=self.get_users_for_states,
+ get_interested_users=self.get_interested_users,
+ )
+
+ async def get_users_for_states(
+ self, state_updates: Iterable[UserPresenceState]
+ ) -> Dict[str, Set[UserPresenceState]]:
+ users_to_state = {
+ user_id: set(state_updates)
+ for user_id in self._config.users_who_should_receive_all_presence
+ }
+ return users_to_state
+
+ async def get_interested_users(
+ self, user_id: str
+ ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+ if user_id in self._config.users_who_should_receive_all_presence:
+ return PresenceRouter.ALL_USERS
+
+ return set()
+
+ @staticmethod
+ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+ """Parse a configuration dictionary from the homeserver config, do
+ some validation and return a typed PresenceRouterConfig.
+
+ Args:
+ config_dict: The configuration dictionary.
+
+ Returns:
+ A validated config object.
+ """
+ # Initialise a typed config object
+ config = PresenceRouterTestConfig()
+
+ config.users_who_should_receive_all_presence = config_dict.get(
+ "users_who_should_receive_all_presence"
+ )
+
+ return config
+
+
class PresenceRouterTestCase(FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -86,9 +133,17 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
- return self.setup_test_homeserver(
+ hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
+ # Load the modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
+ load_legacy_presence_router(hs)
+
+ return hs
def prepare(self, reactor, clock, homeserver):
self.sync_handler = self.hs.get_sync_handler()
@@ -98,7 +153,7 @@ def prepare(self, reactor, clock, homeserver):
{
"presence": {
"presence_router": {
- "module": __name__ + ".PresenceRouterTestModule",
+ "module": __name__ + ".LegacyPresenceRouterTestModule",
"config": {
"users_who_should_receive_all_presence": [
"@presence_gobbler:test",
@@ -109,7 +164,28 @@ def prepare(self, reactor, clock, homeserver):
"send_federation": True,
}
)
+ def test_receiving_all_presence_legacy(self):
+ self.receiving_all_presence_test_body()
+
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler:test",
+ ]
+ },
+ },
+ ],
+ "send_federation": True,
+ }
+ )
def test_receiving_all_presence(self):
+ self.receiving_all_presence_test_body()
+
+ def receiving_all_presence_test_body(self):
"""Test that a user that does not share a room with another other can receive
presence for them, due to presence routing.
"""
@@ -203,7 +279,7 @@ def test_receiving_all_presence(self):
{
"presence": {
"presence_router": {
- "module": __name__ + ".PresenceRouterTestModule",
+ "module": __name__ + ".LegacyPresenceRouterTestModule",
"config": {
"users_who_should_receive_all_presence": [
"@presence_gobbler1:test",
@@ -216,7 +292,30 @@ def test_receiving_all_presence(self):
"send_federation": True,
}
)
+ def test_send_local_online_presence_to_with_module_legacy(self):
+ self.send_local_online_presence_to_with_module_test_body()
+
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler1:test",
+ "@presence_gobbler2:test",
+ "@far_away_person:island",
+ ]
+ },
+ },
+ ],
+ "send_federation": True,
+ }
+ )
def test_send_local_online_presence_to_with_module(self):
+ self.send_local_online_presence_to_with_module_test_body()
+
+ def send_local_online_presence_to_with_module_test_body(self):
"""Tests that send_local_presence_to_users sends local online presence to a set
of specified local and remote users, with a custom PresenceRouter module enabled.
"""
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 7a826c086edd..5446fda5e7a3 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -322,7 +322,7 @@ def test_join_rules(self):
},
)
- # After MSC3083, alias events have no special behavior.
+ # After MSC3083, the allow key is protected from redaction.
self.run_test(
{
"type": "m.room.join_rules",
@@ -344,6 +344,50 @@ def test_join_rules(self):
room_version=RoomVersions.V8,
)
+ def test_member(self):
+ """Member events have changed behavior starting with MSC3375."""
+ self.run_test(
+ {
+ "type": "m.room.member",
+ "event_id": "$test:domain",
+ "content": {
+ "membership": "join",
+ "join_authorised_via_users_server": "@user:domain",
+ "other_key": "stripped",
+ },
+ },
+ {
+ "type": "m.room.member",
+ "event_id": "$test:domain",
+ "content": {"membership": "join"},
+ "signatures": {},
+ "unsigned": {},
+ },
+ )
+
+ # After MSC3375, the join_authorised_via_users_server key is protected
+ # from redaction.
+ self.run_test(
+ {
+ "type": "m.room.member",
+ "content": {
+ "membership": "join",
+ "join_authorised_via_users_server": "@user:domain",
+ "other_key": "stripped",
+ },
+ },
+ {
+ "type": "m.room.member",
+ "content": {
+ "membership": "join",
+ "join_authorised_via_users_server": "@user:domain",
+ },
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.V9,
+ )
+
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 383214ab5046..663960ff534a 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -208,7 +208,7 @@ async def approve_all_signature_checking(_, pdu):
async def _check_event_auth(origin, event, context, *args, **kwargs):
return context
- homeserver.get_federation_handler()._check_event_auth = _check_event_auth
+ homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth
return super().prepare(reactor, clock, homeserver)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index c72a8972a3f1..6c67a16de923 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -130,7 +130,9 @@ def test_rejected_message_event_state(self):
)
with LoggingContext("send_rejected"):
- d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ d = run_in_background(
+ self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
+ )
self.get_success(d)
# that should have been rejected
@@ -182,7 +184,9 @@ def test_rejected_state_event_state(self):
)
with LoggingContext("send_rejected"):
- d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ d = run_in_background(
+ self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
+ )
self.get_success(d)
# that should have been rejected
@@ -311,7 +315,9 @@ async def get_event_auth(
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
d = run_in_background(
- self.handler.on_receive_pdu, OTHER_SERVER, message_event
+ self.hs.get_federation_event_handler().on_receive_pdu,
+ OTHER_SERVER,
+ message_event,
)
self.get_success(d)
@@ -382,7 +388,9 @@ def _build_and_send_join_event(self, other_server, other_user, room_id):
join_event.signatures[other_server] = {"x": "y"}
with LoggingContext("send_join"):
d = run_in_background(
- self.handler.on_send_membership_event, other_server, join_event
+ self.hs.get_federation_event_handler().on_send_membership_event,
+ other_server,
+ join_event,
)
self.get_success(d)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 0a52bc8b721f..671dc7d083c8 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -885,7 +885,7 @@ def default_config(self):
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
- self.federation_handler = hs.get_federation_handler()
+ self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler()
# self.event_builder_for_2 = EventBuilderFactory(hs)
@@ -1026,7 +1026,7 @@ def _add_new_user(self, room_id, user_id):
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
)
- self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
+ self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
# Check that it was successfully persisted.
self.get_success(self.store.get_event(event.event_id))
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 84f05f6c584c..339c039914f1 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
+from synapse.api.room_versions import RoomVersions
from synapse.handlers.sync import SyncConfig
+from synapse.rest import admin
+from synapse.rest.client import knock, login, room
+from synapse.server import HomeServer
from synapse.types import UserID, create_requester
import tests.unittest
@@ -24,8 +31,14 @@
class SyncTestCase(tests.unittest.HomeserverTestCase):
"""Tests Sync Handler."""
- def prepare(self, reactor, clock, hs):
- self.hs = hs
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs: HomeServer):
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
@@ -68,12 +81,124 @@ def test_wait_for_sync_for_user_auth_blocking(self):
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ def test_unknown_room_version(self):
+ """
+ A room with an unknown room version should not break sync (and should be excluded).
+ """
+ inviter = self.register_user("creator", "pass", admin=True)
+ inviter_tok = self.login("@creator:test", "pass")
+
+ user = self.register_user("user", "pass")
+ tok = self.login("user", "pass")
+
+ # Do an initial sync on a different device.
+ requester = create_requester(user)
+ initial_result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user, device_id="dev")
+ )
+ )
+
+ # Create a room as the user.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ # Invite the user to the room as someone else.
+ invite_room = self.helper.create_room_as(inviter, tok=inviter_tok)
+ self.helper.invite(invite_room, targ=user, tok=inviter_tok)
+
+ knock_room = self.helper.create_room_as(
+ inviter, room_version=RoomVersions.V7.identifier, tok=inviter_tok
+ )
+ self.helper.send_state(
+ knock_room,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=inviter_tok,
+ )
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/knock/%s" % (knock_room,),
+ b"{}",
+ tok,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # The rooms should appear in the sync response.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user)
+ )
+ )
+ self.assertIn(joined_room, [r.room_id for r in result.joined])
+ self.assertIn(invite_room, [r.room_id for r in result.invited])
+ self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+ # Test a incremental sync (by providing a since_token).
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config=generate_sync_config(user, device_id="dev"),
+ since_token=initial_result.next_batch,
+ )
+ )
+ self.assertIn(joined_room, [r.room_id for r in result.joined])
+ self.assertIn(invite_room, [r.room_id for r in result.invited])
+ self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+ # Poke the database and update the room version to an unknown one.
+ for room_id in (joined_room, invite_room, knock_room):
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_update(
+ "rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_version": "unknown-room-version"},
+ desc="updated-room-version",
+ )
+ )
+
+ # Blow away caches (supported room versions can only change due to a restart).
+ self.get_success(
+ self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+ )
+ self.store._get_event_cache.clear()
+
+ # The rooms should be excluded from the sync response.
+ # Get a new request key.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user)
+ )
+ )
+ self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+ self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+ self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+ # The rooms should also not be in an incremental sync.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config=generate_sync_config(user, device_id="dev"),
+ since_token=initial_result.next_batch,
+ )
+ )
+ self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+ self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+ self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+
+_request_key = 0
+
-def generate_sync_config(user_id: str) -> SyncConfig:
+def generate_sync_config(
+ user_id: str, device_id: Optional[str] = "device_id"
+) -> SyncConfig:
+ """Generate a sync config (with a unique request key)."""
+ global _request_key
+ _request_key += 1
return SyncConfig(
- user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+ user=UserID.from_string(user_id),
filter_collection=DEFAULT_FILTER_COLLECTION,
is_guest=False,
- request_key="request_key",
- device_id="device_id",
+ request_key=("request_key", _request_key),
+ device_id=device_id,
)
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index e0a3342088d4..c4ba13a6b270 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -125,6 +125,8 @@ def prepare(self, reactor, clock, hs):
)
)
+ self.auth_handler = hs.get_auth_handler()
+
def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated
their email.
@@ -305,6 +307,87 @@ def test_encrypted_message(self):
# We should get emailed about that message
self._check_for_mail()
+ def test_no_email_sent_after_removed(self):
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room,
+ src=self.user_id,
+ tok=self.access_token,
+ targ=self.others[0].id,
+ )
+ self.helper.join(
+ room=room,
+ user=self.others[0].id,
+ tok=self.others[0].token,
+ )
+
+ # The other user sends a single message.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+ # We should get emailed about that message
+ self._check_for_mail()
+
+ # disassociate the user's email address
+ self.get_success(
+ self.auth_handler.delete_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address="a@example.com",
+ )
+ )
+
+ # check that the pusher for that email address has been deleted
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 0)
+
+ def test_remove_unlinked_pushers_background_job(self):
+ """Checks that all existing pushers associated with unlinked email addresses are removed
+ upon running the remove_deleted_email_pushers background update.
+ """
+ # disassociate the user's email address manually (without deleting the pusher).
+ # This resembles the old behaviour, which the background update below is intended
+ # to clean up.
+ self.get_success(
+ self.hs.get_datastore().user_delete_threepid(
+ self.user_id, "email", "a@example.com"
+ )
+ )
+
+ # Run the "remove_deleted_email_pushers" background job
+ self.get_success(
+ self.hs.get_datastore().db_pool.simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "remove_deleted_email_pushers",
+ "progress_json": "{}",
+ "depends_on": None,
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.hs.get_datastore().db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ while not self.get_success(
+ self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
+ by=0.1,
+ )
+
+ # Check that all pushers with unlinked addresses were deleted
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 0)
+
def _check_for_mail(self):
"""Check that the user receives an email notification"""
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index db80a0bdbdaf..b25a06b4271a 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,7 +20,7 @@
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
@@ -150,6 +150,7 @@ def test_invites(self):
"invite",
event.event_id,
event.internal_metadata.stream_ordering,
+ RoomVersions.V1.identifier,
)
],
)
@@ -216,7 +217,7 @@ def test_get_rooms_for_user_with_stream_ordering(self):
self.check(
"get_rooms_for_user_with_stream_ordering",
(USER_ID_2,),
- {(ROOM_ID, expected_pos)},
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -305,7 +306,10 @@ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
expected_pos = PersistedEventPosition(
"master", j2.internal_metadata.stream_ordering
)
- self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
+ self.assertEqual(
+ joined_rooms,
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+ )
event_id = 0
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index af5dfca752b9..92a5b53e11e7 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -205,7 +205,7 @@ def test_send_typing_sharded(self):
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
- federation = self.hs.get_federation_handler()
+ federation = self.hs.get_federation_event_handler()
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index c4afe5c3d90b..a3679be20539 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import urllib.parse
+from parameterized import parameterized
+
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
@@ -45,49 +46,23 @@ def prepare(self, reactor, clock, hs):
self.other_user_device_id,
)
- def test_no_auth(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_no_auth(self, method: str):
"""
Try to get a device of an user without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("PUT", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("DELETE", self.url, b"{}")
+ channel = self.make_request(method, self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_requester_is_no_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
channel = self.make_request(
- "GET",
- self.url,
- access_token=self.other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "PUT",
- self.url,
- access_token=self.other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
+ method,
self.url,
access_token=self.other_user_token,
)
@@ -95,7 +70,8 @@ def test_requester_is_no_admin(self):
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_user_does_not_exist(self, method: str):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
@@ -105,7 +81,7 @@ def test_user_does_not_exist(self):
)
channel = self.make_request(
- "GET",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -113,25 +89,8 @@ def test_user_does_not_exist(self):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_user_is_not_local(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_user_is_not_local(self, method: str):
"""
Tests that a lookup for a user that is not a local returns a 400
"""
@@ -141,25 +100,7 @@ def test_user_is_not_local(self):
)
channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
- channel = self.make_request(
- "DELETE",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -219,12 +160,11 @@ def test_update_device_too_long_display_name(self):
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
}
- body = json.dumps(update)
channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content=update,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -275,12 +215,11 @@ def test_update_display_name(self):
Tests a normal successful update of display name
"""
# Set new display_name
- body = json.dumps({"display_name": "new displayname"})
channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"display_name": "new displayname"},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -529,12 +468,11 @@ def test_unknown_devices(self):
"""
Tests that a remove of a device that does not exist returns 200.
"""
- body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"devices": ["unknown_device1", "unknown_device2"]},
)
# Delete unknown devices returns status 200
@@ -560,12 +498,11 @@ def test_delete_devices(self):
device_ids.append(str(d["device_id"]))
# Delete devices
- body = json.dumps({"devices": device_ids})
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"devices": device_ids},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 000000000000..4927321e5a4f
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/registration_tokens"
+
+ def _new_token(self, **kwargs):
+ """Helper function to create a token."""
+ token = kwargs.get(
+ "token",
+ "".join(random.choices(string.ascii_letters, k=8)),
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": kwargs.get("uses_allowed", None),
+ "pending": kwargs.get("pending", 0),
+ "completed": kwargs.get("completed", 0),
+ "expiry_time": kwargs.get("expiry_time", None),
+ },
+ )
+ )
+ return token
+
+ # CREATION
+
+ def test_create_no_auth(self):
+ """Try to create a token without authentication."""
+ channel = self.make_request("POST", self.url + "/new", {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_create_requester_not_admin(self):
+ """Try to create a token while not an admin."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_create_using_defaults(self):
+ """Create a token using all the defaults."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_specifying_fields(self):
+ """Create a token specifying the value of all fields."""
+ data = {
+ "token": "abcd",
+ "uses_allowed": 1,
+ "expiry_time": self.clock.time_msec() + 1000000,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], "abcd")
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_with_null_value(self):
+ """Create a token specifying unlimited uses and no expiry."""
+ data = {
+ "uses_allowed": None,
+ "expiry_time": None,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_token_too_long(self):
+ """Check token longer than 64 chars is invalid."""
+ data = {"token": "a" * 65}
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_invalid_chars(self):
+ """Check you can't create token with invalid characters."""
+ data = {
+ "token": "abc/def",
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_already_exists(self):
+ """Check you can't create token that already exists."""
+ data = {
+ "token": "abcd",
+ }
+
+ channel1 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+ channel2 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+ self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_unable_to_generate_token(self):
+ """Check right error is raised when server can't generate unique token."""
+ # Create all possible single character tokens
+ tokens = []
+ for c in string.ascii_letters + string.digits + "-_":
+ tokens.append(
+ {
+ "token": c,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ }
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert_many(
+ "registration_tokens",
+ tokens,
+ "create_all_registration_tokens",
+ )
+ )
+
+ # Check creating a single character token fails with a 500 status code
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_create_uses_allowed(self):
+ """Check you can only create a token with good values for uses_allowed."""
+ # Should work with 0 (token is invalid from the start)
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+ # Should fail with negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_expiry_time(self):
+ """Check you can't create a token with an invalid expiry_time."""
+ # Should fail with a time in the past
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() - 10000},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() + 1000000.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_length(self):
+ """Check you can only generate a token with a valid length."""
+ # Should work with 64
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 64},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 64)
+
+ # Should fail with 0
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 8.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with 65
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 65},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # UPDATING
+
+ def test_update_no_auth(self):
+ """Try to update a token without authentication."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_update_requester_not_admin(self):
+ """Try to update a token while not an admin."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_update_non_existent(self):
+ """Try to update a token that doesn't exist."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234",
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_update_uses_allowed(self):
+ """Test updating just uses_allowed."""
+ # Create new token using default values
+ token = self._new_token()
+
+ # Should succeed with 1
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with 0 (makes token invalid)
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should fail with a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_expiry_time(self):
+ """Test updating just expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ # Should succeed with a time in the future
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should fail with a time in the past
+ past_time = self.clock.time_msec() - 10000
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": past_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time + 0.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_both(self):
+ """Test updating both uses_allowed and expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ data = {
+ "uses_allowed": 1,
+ "expiry_time": new_expiry_time,
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+ def test_update_invalid_type(self):
+ """Test using invalid types doesn't work."""
+ # Create new token using default values
+ token = self._new_token()
+
+ data = {
+ "uses_allowed": False,
+ "expiry_time": "1626430124000",
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # DELETING
+
+ def test_delete_no_auth(self):
+ """Try to delete a token without authentication."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_delete_requester_not_admin(self):
+ """Try to delete a token while not an admin."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_non_existent(self):
+ """Try to delete a token that doesn't exist."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_delete(self):
+ """Test deleting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # GETTING ONE
+
+ def test_get_no_auth(self):
+ """Try to get a token without authentication."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_get_requester_not_admin(self):
+ """Try to get a token while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_get_non_existent(self):
+ """Try to get a token that doesn't exist."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_get(self):
+ """Test getting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], token)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ # LISTING
+
+ def test_list_no_auth(self):
+ """Try to list tokens without authentication."""
+ channel = self.make_request("GET", self.url, {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_list_requester_not_admin(self):
+ """Try to list tokens while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_list_all(self):
+ """Test listing all tokens."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+ token_info = channel.json_body["registration_tokens"][0]
+ self.assertEqual(token_info["token"], token)
+ self.assertIsNone(token_info["uses_allowed"])
+ self.assertIsNone(token_info["expiry_time"])
+ self.assertEqual(token_info["pending"], 0)
+ self.assertEqual(token_info["completed"], 0)
+
+ def test_list_invalid_query_parameter(self):
+ """Test with `valid` query parameter not `true` or `false`."""
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=x",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _test_list_query_parameter(self, valid: str):
+ """Helper used to test both valid=true and valid=false."""
+ # Create 2 valid and 2 invalid tokens.
+ now = self.hs.get_clock().time_msec()
+ # Create always valid token
+ valid1 = self._new_token()
+ # Create token that hasn't been used up
+ valid2 = self._new_token(uses_allowed=1)
+ # Create token that has expired
+ invalid1 = self._new_token(expiry_time=now - 10000)
+ # Create token that has been used up but hasn't expired
+ invalid2 = self._new_token(
+ uses_allowed=2,
+ pending=1,
+ completed=1,
+ expiry_time=now + 1000000,
+ )
+
+ if valid == "true":
+ tokens = [valid1, valid2]
+ else:
+ tokens = [invalid1, invalid2]
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=" + valid,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+ token_info_1 = channel.json_body["registration_tokens"][0]
+ token_info_2 = channel.json_body["registration_tokens"][1]
+ self.assertIn(token_info_1["token"], tokens)
+ self.assertIn(token_info_2["token"], tokens)
+
+ def test_list_valid(self):
+ """Test listing just valid tokens."""
+ self._test_list_query_parameter(valid="true")
+
+ def test_list_invalid(self):
+ """Test listing just invalid tokens."""
+ self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index c9d4731017c1..40e032df7f8c 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -29,123 +29,6 @@
"""Tests admin REST events for /rooms paths."""
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/shutdown_room/" + room_id
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/shutdown_room/" + room_id
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room."""
-
- url = "rooms/%s/initialSync" % (room_id,)
- channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
@parameterized_class(
("method", "url_template"),
[
@@ -557,51 +440,6 @@ def _assert_peek(self, room_id, expect_code):
)
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API."""
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in PURGE_TABLES:
- count = self.get_success(
- self.store.db_pool.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
-
-
class RoomTestCase(unittest.HomeserverTestCase):
"""Test /room admin API."""
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
new file mode 100644
index 000000000000..fbceba325494
--- /dev/null
+++ b/tests/rest/admin/test_server_notice.py
@@ -0,0 +1,450 @@
+# Copyright 2021 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login, room, sync
+from synapse.storage.roommember import RoomsForUser
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class ServerNoticeTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.room_shutdown_handler = hs.get_room_shutdown_handler()
+ self.pagination_handler = hs.get_pagination_handler()
+ self.server_notices_manager = self.hs.get_server_notices_manager()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/send_server_notice"
+
+ def test_no_auth(self):
+ """Try to send a server notice without authentication."""
+ channel = self.make_request("POST", self.url)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """If the user is not a server admin, an error is returned."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_user_does_not_exist(self):
+ """Tests that a lookup for a user that does not exist returns a 404"""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": "@unknown_person:test", "content": ""},
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": "@unknown_person:unknown_domain",
+ "content": "",
+ },
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Server notices can only be sent to local users", channel.json_body["error"]
+ )
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_invalid_parameter(self):
+ """If parameters are invalid, an error is returned."""
+
+ # no content, no user
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+ # no content
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # no body
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user, "content": ""},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual("'body' not in content", channel.json_body["error"])
+
+ # no msgtype
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user, "content": {"body": ""}},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+
+ def test_server_notice_disabled(self):
+ """Tests that server returns error if server notice is disabled"""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": "",
+ },
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(
+ "Server notices are not enabled on this server", channel.json_body["error"]
+ )
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice(self):
+ """
+ Tests that sending two server notices is successfully,
+ the server uses the same room and do not send messages twice.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token)
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # invalidate cache of server notices room_ids
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has no new invites or memberships
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 2)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ self.assertEqual(messages[1]["content"]["body"], "test msg two")
+ self.assertEqual(messages[1]["sender"], "@notices:test")
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice_leave_room(self):
+ """
+ Tests that sending a server notices is successfully.
+ The user leaves the room and the second message appears
+ in a new room.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ first_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # user leaves the romm
+ self.helper.leave(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+
+ # user is not member anymore
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # invalidate cache of server notices room_ids
+ # if server tries to send to a cached room_id the user gets the message
+ # in old room
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ second_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=second_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg two")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ # room has the same id
+ self.assertNotEqual(first_room_id, second_room_id)
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice_delete_room(self):
+ """
+ Tests that the user get server notice in a new room
+ after the first server notice room was deleted.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ first_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # shut down and purge room
+ self.get_success(
+ self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user)
+ )
+ self.get_success(self.pagination_handler.purge_room(first_room_id))
+
+ # user is not member anymore
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(first_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
+
+ # invalidate cache of server notices room_ids
+ # if server tries to send to a cached room_id it gives an error
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ second_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=second_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get message
+ messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg two")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ # second room has new ID
+ self.assertNotEqual(first_room_id, second_room_id)
+
+ def _check_invite_and_join_status(
+ self, user_id: str, expected_invites: int, expected_memberships: int
+ ) -> RoomsForUser:
+ """Check invite and room membership status of a user.
+
+ Args
+ user_id: user to check
+ expected_invites: number of expected invites of this user
+ expected_memberships: number of expected room memberships of this user
+ Returns
+ room_ids from the rooms that the user is invited
+ """
+
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(user_id)
+ )
+ self.assertEqual(expected_invites, len(invited_rooms))
+
+ room_ids = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(expected_memberships, len(room_ids))
+
+ return invited_rooms
+
+ def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]:
+ """
+ Do a sync and get messages of a room.
+
+ Args
+ room_id: room that contains the messages
+ token: access token of user
+
+ Returns
+ list of messages contained in the room
+ """
+ channel = self.make_request(
+ "GET", "/_matrix/client/r0/sync", access_token=token
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Get the messages
+ room = channel.json_body["rooms"]["join"][room_id]
+ messages = [
+ x for x in room["timeline"]["events"] if x["type"] == "m.room.message"
+ ]
+ return messages
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ef7727523870..ee204c404b12 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1431,12 +1431,14 @@ def test_create_user(self):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"]
)
self.assertEqual(
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
)
+ self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
@@ -1676,18 +1678,53 @@ def test_set_threepid(self):
Test setting threepid for an other user.
"""
- # Delete old and add new threepid to user
+ # Add two threepids to user
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob1@bob.bob"},
+ {"medium": "email", "address": "bob2@bob.bob"},
+ ],
+ },
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
+ # result does not always have the same sort order, therefore it becomes sorted
+ sorted_result = sorted(
+ channel.json_body["threepids"], key=lambda k: k["address"]
+ )
+ self.assertEqual("email", sorted_result[0]["medium"])
+ self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
+ self.assertEqual("email", sorted_result[1]["medium"])
+ self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Set a new and remove a threepid
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob2@bob.bob"},
+ {"medium": "email", "address": "bob3@bob.bob"},
+ ],
+ },
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
# Get user
channel = self.make_request(
@@ -1698,8 +1735,24 @@ def test_set_threepid(self):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Remove threepids
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"threepids": []},
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self._check_fields(channel.json_body)
def test_set_external_id(self):
"""
@@ -1778,6 +1831,7 @@ def test_set_external_id(self):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
channel.json_body["external_ids"],
[
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/test_account.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_account.py
rename to tests/rest/client/test_account.py
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/test_auth.py
similarity index 99%
rename from tests/rest/client/v2_alpha/test_auth.py
rename to tests/rest/client/test_auth.py
index cf5cfb910c8c..e2fcbdc63ac6 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -25,7 +25,7 @@
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/test_capabilities.py
similarity index 54%
rename from tests/rest/client/v2_alpha/test_capabilities.py
rename to tests/rest/client/test_capabilities.py
index 13b3c5f499b7..422361b62a5a 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -30,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
- self.store = hs.get_datastore()
self.config = hs.config
self.auth_handler = hs.get_auth_handler()
return hs
+ def prepare(self, reactor, clock, hs):
+ self.localpart = "user"
+ self.password = "pass"
+ self.user = self.register_user(self.localpart, self.password)
+
def test_check_auth_required(self):
channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401)
def test_get_room_version_capabilities(self):
- self.register_user("user", "pass")
- access_token = self.login("user", "pass")
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
@@ -57,10 +60,7 @@ def test_get_room_version_capabilities(self):
)
def test_get_change_password_capabilities_password_login(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
- access_token = self.login(user, password)
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
@@ -70,12 +70,9 @@ def test_get_change_password_capabilities_password_login(self):
@override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -87,12 +84,9 @@ def test_get_change_password_capabilities_localdb_disabled(self):
@override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -102,14 +96,86 @@ def test_get_change_password_capabilities_password_disabled(self):
self.assertEqual(channel.code, 200)
self.assertFalse(capabilities["m.change_password"]["enabled"])
+ def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
+ """Test that per default msc3283 is disabled server returns `m.change_password`."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
+ self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
+ self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
+ self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
+
+ @override_config({"experimental_features": {"msc3283_enabled": True}})
+ def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
+ """Test if msc3283 is enabled server returns capabilities."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+ @override_config(
+ {
+ "enable_set_displayname": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_get_set_displayname_capabilities_displayname_disabled(self):
+ """Test if set displayname is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+
+ @override_config(
+ {
+ "enable_set_avatar_url": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+ """Test if set avatar_url is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+
+ @override_config(
+ {
+ "enable_3pid_changes": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_change_3pid_capabilities_3pid_disabled(self):
+ """Test if change 3pid is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
@override_config({"experimental_features": {"msc3244_enabled": False}})
def test_get_does_not_include_msc3244_fields_when_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -122,12 +188,9 @@ def test_get_does_not_include_msc3244_fields_when_disabled(self):
)
def test_get_does_include_msc3244_fields_when_enabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/test_directory.py
similarity index 100%
rename from tests/rest/client/v1/test_directory.py
rename to tests/rest/client/test_directory.py
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/test_events.py
similarity index 100%
rename from tests/rest/client/v1/test_events.py
rename to tests/rest/client/test_events.py
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/test_filter.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_filter.py
rename to tests/rest/client/test_filter.py
diff --git a/tests/rest/client/v2_alpha/test_groups.py b/tests/rest/client/test_groups.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_groups.py
rename to tests/rest/client/test_groups.py
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
new file mode 100644
index 000000000000..d7fa635eae1d
--- /dev/null
+++ b/tests/rest/client/test_keys.py
@@ -0,0 +1,91 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from http import HTTPStatus
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import keys, login
+
+from tests import unittest
+
+
+class KeyQueryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ keys.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def test_rejects_device_id_ice_key_outside_of_list(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: "device_id1",
+ },
+ },
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_rejects_device_key_given_as_map_to_bool(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: {
+ "device_id1": True,
+ },
+ },
+ },
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_requires_device_key(self):
+ """`device_keys` is required. We should complain if it's missing."""
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {},
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/test_login.py
similarity index 99%
rename from tests/rest/client/v1/test_login.py
rename to tests/rest/client/test_login.py
index eba3552b19ac..5b2243fe5205 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -32,7 +32,7 @@
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
-from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/test_password_policy.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_password_policy.py
rename to tests/rest/client/test_password_policy.py
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 91d0762cb0ab..c0de4c93a806 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.errors import Codes
+from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT
from synapse.rest import admin
from synapse.rest.client import login, room, sync
@@ -203,3 +205,79 @@ def test_admins_can_tombstone_room(self):
tok=self.admin_access_token,
expect_code=200, # expect success
)
+
+ def test_cannot_set_string_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL "0"
+ room_power_levels["users"].update({self.user_user_id: "0"})
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_cannot_set_unsafe_large_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL above the max safe integer
+ room_power_levels["users"].update(
+ {self.user_user_id: CANONICALJSON_MAX_INT + 1}
+ )
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_cannot_set_unsafe_small_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL below the minimum safe integer
+ room_power_levels["users"].update(
+ {self.user_user_id: CANONICALJSON_MIN_INT - 1}
+ )
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/test_presence.py
similarity index 100%
rename from tests/rest/client/v1/test_presence.py
rename to tests/rest/client/test_presence.py
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/test_profile.py
similarity index 100%
rename from tests/rest/client/v1/test_profile.py
rename to tests/rest/client/test_profile.py
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
similarity index 100%
rename from tests/rest/client/v1/test_push_rule_attrs.py
rename to tests/rest/client/test_push_rule_attrs.py
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/test_register.py
similarity index 63%
rename from tests/rest/client/v2_alpha/test_register.py
rename to tests/rest/client/test_register.py
index fecda037a54b..9f3ab2c9858a 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
from tests import unittest
from tests.unittest import override_config
@@ -204,6 +205,371 @@ def test_POST_ratelimiting(self):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_requires_token(self):
+ username = "kermit"
+ device_id = "frogfone"
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params = {
+ "username": username,
+ "password": "monkey",
+ "device_id": device_id,
+ }
+
+ # Request without auth to get flows and session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+ # Synapse adds a dummy stage to differentiate flows where otherwise one
+ # flow would be a subset of another flow.
+ self.assertCountEqual(
+ [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+ (f["stages"] for f in flows),
+ )
+ session = channel.json_body["session"]
+
+ # Do the registration token stage and check it has completed
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ # Do the m.login.dummy stage and check registration was successful
+ params["auth"] = {
+ "type": LoginType.DUMMY,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ det_data = {
+ "user_id": f"@{username}:{self.hs.hostname}",
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, channel.json_body)
+
+ # Check the `completed` counter has been incremented and pending is 0
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["completed"], 1)
+ self.assertEquals(res["pending"], 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_invalid(self):
+ params = {
+ "username": "kermit",
+ "password": "monkey",
+ }
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Test with token param missing (invalid)
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with non-string (invalid)
+ params["auth"]["token"] = 1234
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with unknown token (invalid)
+ params["auth"]["token"] = "1234"
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_limit_uses(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ # Create token that can be used once
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ # Do 2 requests without auth to get two session IDs
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with session1 and check `pending` is 1
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Repeat request to make sure pending isn't increased again
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 1)
+
+ # Check auth fails when using token with session2
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Check pending=0 and completed=1
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["pending"], 0)
+ self.assertEquals(res["completed"], 1)
+
+ # Check auth still fails when using token with session2
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_expiry(self):
+ token = "abcd"
+ now = self.hs.get_clock().time_msec()
+ store = self.hs.get_datastore()
+ # Create token that expired yesterday
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": now - 24 * 60 * 60 * 1000,
+ },
+ )
+ )
+ params = {"username": "kermit", "password": "monkey"}
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Check authentication fails with expired token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Update token so it expires tomorrow
+ self.get_success(
+ store.db_pool.simple_update_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+ )
+ )
+
+ # Check authentication succeeds
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry(self):
+ """Test `pending` is decremented when an uncompleted session expires."""
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do 2 requests without auth to get two session IDs
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with both sessions
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params2))
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ # Check `result` of registration token stage for session1 is `True`
+ result1 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session1,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertTrue(db_to_json(result1))
+
+ # Check `result` for session2 is the token used
+ result2 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session2,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertEquals(db_to_json(result2), token)
+
+ # Delete both sessions (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
+ # Check pending is now 0
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry_deleted_token(self):
+ """Test session expiry doesn't break when the token is deleted.
+
+ 1. Start but don't complete UIA with a registration token
+ 2. Delete the token from the database
+ 3. Expire the session
+ """
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do request without auth to get a session ID
+ params = {"username": "kermit", "password": "monkey"}
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Use token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params))
+
+ # Delete token
+ self.get_success(
+ store.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ )
+ )
+
+ # Delete session (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
def test_advertised_flows(self):
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ def test_background_job(self):
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [register.register_servlets]
+ url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+ def default_config(self):
+ config = super().default_config()
+ config["registration_requires_token"] = True
+ return config
+
+ def test_GET_token_valid(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], True)
+
+ def test_GET_token_invalid(self):
+ token = "1234"
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], False)
+
+ @override_config(
+ {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+ )
+ def test_GET_ratelimiting(self):
+ token = "1234"
+
+ for i in range(0, 6):
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+
+ if i == 5:
+ self.assertEquals(channel.result["code"], b"429", channel.result)
+ retry_after_ms = int(channel.json_body["retry_after_ms"])
+ else:
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/test_relations.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_relations.py
rename to tests/rest/client/test_relations.py
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/test_report_event.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_report_event.py
rename to tests/rest/client/test_report_event.py
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/test_rooms.py
similarity index 100%
rename from tests/rest/client/v1/test_rooms.py
rename to tests/rest/client/test_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_sendtodevice.py
rename to tests/rest/client/test_sendtodevice.py
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_shared_rooms.py
rename to tests/rest/client/test_shared_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/test_sync.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_sync.py
rename to tests/rest/client/test_sync.py
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/test_typing.py
similarity index 100%
rename from tests/rest/client/v1/test_typing.py
rename to tests/rest/client/test_typing.py
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
similarity index 100%
rename from tests/rest/client/v2_alpha/test_upgrade_room.py
rename to tests/rest/client/test_upgrade_room.py
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/utils.py
similarity index 100%
rename from tests/rest/client/v1/utils.py
rename to tests/rest/client/utils.py
diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py
deleted file mode 100644
index 5e83dba2ed6f..000000000000
--- a/tests/rest/client/v1/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 6085444b9da8..2f7eebfe6931 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -21,7 +21,7 @@
from urllib import parse
import attr
-from parameterized import parameterized_class
+from parameterized import parameterized, parameterized_class
from PIL import Image as Image
from twisted.internet import defer
@@ -473,6 +473,43 @@ def _test_thumbnail(self, method, expected_body, expected_found):
},
)
+ @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
+ def test_same_quality(self, method, desired_size):
+ """Test that choosing between thumbnails with the same quality rating succeeds.
+
+ We are not particular about which thumbnail is chosen."""
+ self.assertIsNotNone(
+ self.thumbnail_resource._select_thumbnail(
+ desired_width=desired_size,
+ desired_height=desired_size,
+ desired_method=method,
+ desired_type=self.test_image.content_type,
+ # Provide two identical thumbnails which are guaranteed to have the same
+ # quality rating.
+ thumbnail_infos=[
+ {
+ "thumbnail_width": 32,
+ "thumbnail_height": 32,
+ "thumbnail_method": method,
+ "thumbnail_type": self.test_image.content_type,
+ "thumbnail_length": 256,
+ "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ },
+ {
+ "thumbnail_width": 32,
+ "thumbnail_height": 32,
+ "thumbnail_method": method,
+ "thumbnail_type": self.test_image.content_type,
+ "thumbnail_length": 256,
+ "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ },
+ ],
+ file_id=f"image{self.test_image.extension}",
+ url_cache=None,
+ server_name=None,
+ )
+ )
+
def test_x_robots_tag_header(self):
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3785799f46d2..61c9d7c2ef96 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,7 +75,8 @@ def setUp(self):
)
self.handler = self.homeserver.get_federation_handler()
- self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
+ federation_event_handler = self.homeserver.get_federation_event_handler()
+ federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
context
)
self.client = self.homeserver.get_federation_client()
@@ -86,9 +87,7 @@ def setUp(self):
# Send the join, it should return None (which is not an error)
self.assertEqual(
self.get_success(
- self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
- )
+ federation_event_handler.on_receive_pdu("test.serv", join_event)
),
None,
)
@@ -133,11 +132,10 @@ async def post_json(destination, path, data, headers=None, timeout=0):
}
)
+ federation_event_handler = self.homeserver.get_federation_event_handler()
with LoggingContext("test-context"):
failure = self.get_failure(
- self.handler.on_receive_pdu(
- "test.serv", lying_event, sent_to_us_directly=True
- ),
+ federation_event_handler.on_receive_pdu("test.serv", lying_event),
FederationError,
)
diff --git a/tests/unittest.py b/tests/unittest.py
index 3eec9c4d5b6f..f2c90cc47b53 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -252,7 +252,7 @@ def setUp(self):
reactor=self.reactor,
)
- from tests.rest.client.v1.utils import RestHelper
+ from tests.rest.client.utils import RestHelper
self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))