From dac97642e41f3f4bc0deff0c80b6a3f7acb4dbc0 Mon Sep 17 00:00:00 2001
From: Mathieu Velten <mathieuv@matrix.org>
Date: Thu, 10 Aug 2023 11:10:55 +0200
Subject: [PATCH 1/4] Implements admin API to lock an user (MSC3939) (#15870)

---
 changelog.d/15870.feature                     |   1 +
 docs/admin_api/user_admin_api.md              |   1 +
 .../configuration/config_documentation.md     |   2 +
 synapse/_scripts/synapse_port_db.py           |   2 +-
 synapse/api/auth/__init__.py                  |   1 +
 synapse/api/auth/internal.py                  |  15 ++-
 synapse/api/auth/msc3861_delegated.py         |  13 ++
 synapse/api/errors.py                         |   2 +
 synapse/config/user_directory.py              |   1 +
 synapse/handlers/admin.py                     |   1 +
 synapse/handlers/user_directory.py            |   5 +-
 synapse/rest/admin/users.py                   |  17 +++
 synapse/rest/client/logout.py                 |   8 +-
 .../storage/databases/main/registration.py    |  62 +++++++++-
 .../storage/databases/main/user_directory.py  |  11 +-
 .../main/delta/80/01_users_alter_locked.sql   |  16 +++
 tests/api/test_auth.py                        |   3 +
 tests/rest/admin/test_user.py                 | 111 +++++++++++++++++-
 tests/storage/test_registration.py            |   1 +
 19 files changed, 262 insertions(+), 11 deletions(-)
 create mode 100644 changelog.d/15870.feature
 create mode 100644 synapse/storage/schema/main/delta/80/01_users_alter_locked.sql

diff --git a/changelog.d/15870.feature b/changelog.d/15870.feature
new file mode 100644
index 000000000000..527220d637d8
--- /dev/null
+++ b/changelog.d/15870.feature
@@ -0,0 +1 @@
+Implements an admin API to lock an user without deactivating them. Based on [MSC3939](https://github.com/matrix-org/matrix-spec-proposals/pull/3939).
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index ac4f635099e6..c269ce6af0a5 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -146,6 +146,7 @@ Body parameters:
 - `admin` - **bool**, optional, defaults to `false`. Whether the user is a homeserver administrator,
   granting them access to the Admin API, among other things.
 - `deactivated` - **bool**, optional. If unspecified, deactivation state will be left unchanged.
+- `locked` - **bool**, optional. If unspecified, locked state will be left unchanged.
 
   Note: the `password` field must also be set if both of the following are true:
   - `deactivated` is set to `false` and the user was previously deactivated (you are reactivating this user)
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 2987c9332d14..a17a8c290033 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -3631,6 +3631,7 @@ This option has the following sub-options:
 * `prefer_local_users`: Defines whether to prefer local users in search query results.
    If set to true, local users are more likely to appear above remote users when searching the
    user directory. Defaults to false.
+* `show_locked_users`: Defines whether to show locked users in search query results. Defaults to false.
 
 Example configuration:
 ```yaml
@@ -3638,6 +3639,7 @@ user_directory:
     enabled: false
     search_all_users: true
     prefer_local_users: true
+    show_locked_users: true
 ```
 ---
 ### `user_consent`
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 22c84fbd5b3f..1300aaf63c92 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -123,7 +123,7 @@
     "redactions": ["have_censored"],
     "room_stats_state": ["is_federatable"],
     "rooms": ["is_public", "has_auth_chain_index"],
-    "users": ["shadow_banned", "approved"],
+    "users": ["shadow_banned", "approved", "locked"],
     "un_partial_stated_event_stream": ["rejection_status_changed"],
     "users_who_share_rooms": ["share_private"],
     "per_user_experimental_features": ["enabled"],
diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py
index 90cfe39d7623..bb3f50f2ddbe 100644
--- a/synapse/api/auth/__init__.py
+++ b/synapse/api/auth/__init__.py
@@ -60,6 +60,7 @@ async def get_user_by_req(
         request: SynapseRequest,
         allow_guest: bool = False,
         allow_expired: bool = False,
+        allow_locked: bool = False,
     ) -> Requester:
         """Get a registered user's ID.
 
diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py
index e2ae198b196e..6a5fd44ec01c 100644
--- a/synapse/api/auth/internal.py
+++ b/synapse/api/auth/internal.py
@@ -58,6 +58,7 @@ async def get_user_by_req(
         request: SynapseRequest,
         allow_guest: bool = False,
         allow_expired: bool = False,
+        allow_locked: bool = False,
     ) -> Requester:
         """Get a registered user's ID.
 
@@ -79,7 +80,7 @@ async def get_user_by_req(
         parent_span = active_span()
         with start_active_span("get_user_by_req"):
             requester = await self._wrapped_get_user_by_req(
-                request, allow_guest, allow_expired
+                request, allow_guest, allow_expired, allow_locked
             )
 
             if parent_span:
@@ -107,6 +108,7 @@ async def _wrapped_get_user_by_req(
         request: SynapseRequest,
         allow_guest: bool,
         allow_expired: bool,
+        allow_locked: bool,
     ) -> Requester:
         """Helper for get_user_by_req
 
@@ -126,6 +128,17 @@ async def _wrapped_get_user_by_req(
                     access_token, allow_expired=allow_expired
                 )
 
+                # Deny the request if the user account is locked.
+                if not allow_locked and await self.store.get_user_locked_status(
+                    requester.user.to_string()
+                ):
+                    raise AuthError(
+                        401,
+                        "User account has been locked",
+                        errcode=Codes.USER_LOCKED,
+                        additional_fields={"soft_logout": True},
+                    )
+
                 # Deny the request if the user account has expired.
                 # This check is only done for regular users, not appservice ones.
                 if not allow_expired:
diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index bd4fc9c0ee3d..9524102a3037 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -27,6 +27,7 @@
 from synapse.api.auth.base import BaseAuth
 from synapse.api.errors import (
     AuthError,
+    Codes,
     HttpResponseException,
     InvalidClientTokenError,
     OAuthInsufficientScopeError,
@@ -196,6 +197,7 @@ async def get_user_by_req(
         request: SynapseRequest,
         allow_guest: bool = False,
         allow_expired: bool = False,
+        allow_locked: bool = False,
     ) -> Requester:
         access_token = self.get_access_token_from_request(request)
 
@@ -205,6 +207,17 @@ async def get_user_by_req(
             # so that we don't provision the user if they don't have enough permission:
             requester = await self.get_user_by_access_token(access_token, allow_expired)
 
+            # Deny the request if the user account is locked.
+            if not allow_locked and await self.store.get_user_locked_status(
+                requester.user.to_string()
+            ):
+                raise AuthError(
+                    401,
+                    "User account has been locked",
+                    errcode=Codes.USER_LOCKED,
+                    additional_fields={"soft_logout": True},
+                )
+
         if not allow_guest and requester.is_guest:
             raise OAuthInsufficientScopeError([SCOPE_MATRIX_API])
 
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 3546aaf7c399..7ffd72c42cd4 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -80,6 +80,8 @@ class Codes(str, Enum):
     WEAK_PASSWORD = "M_WEAK_PASSWORD"
     INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
+    # USER_LOCKED = "M_USER_LOCKED"
+    USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
 
     # Part of MSC3848
     # https://github.com/matrix-org/matrix-spec-proposals/pull/3848
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c9e18b91e9d2..f60ec2ea66b7 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -35,3 +35,4 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.user_directory_search_prefer_local_users = user_directory_config.get(
             "prefer_local_users", False
         )
+        self.show_locked_users = user_directory_config.get("show_locked_users", False)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 119c7f838481..0e812a6d8b51 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -67,6 +67,7 @@ async def get_user(self, user: UserID) -> Optional[JsonDict]:
             "name",
             "admin",
             "deactivated",
+            "locked",
             "shadow_banned",
             "creation_ts",
             "appservice_id",
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 05197edc9546..a0f5568000f0 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -94,6 +94,7 @@ def __init__(self, hs: "HomeServer"):
         self.is_mine_id = hs.is_mine_id
         self.update_user_directory = hs.config.worker.should_update_user_directory
         self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
+        self.show_locked_users = hs.config.userdirectory.show_locked_users
         self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self._hs = hs
 
@@ -144,7 +145,9 @@ async def search_users(
                     ]
                 }
         """
-        results = await self.store.search_user_dir(user_id, search_term, limit)
+        results = await self.store.search_user_dir(
+            user_id, search_term, limit, self.show_locked_users
+        )
 
         # Remove any spammy users from the results.
         non_spammy_users = []
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index e0257daa751d..04d9ef25b78e 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -280,6 +280,17 @@ async def on_PUT(
                 HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
             )
 
+        lock = body.get("locked", False)
+        if not isinstance(lock, bool):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "'locked' parameter is not of type boolean"
+            )
+
+        if deactivate and lock:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "An user can't be deactivated and locked"
+            )
+
         approved: Optional[bool] = None
         if "approved" in body and self._msc3866_enabled:
             approved = body["approved"]
@@ -397,6 +408,12 @@ async def on_PUT(
                         target_user.to_string()
                     )
 
+            if "locked" in body:
+                if lock and not user["locked"]:
+                    await self.store.set_user_locked_status(user_id, True)
+                elif not lock and user["locked"]:
+                    await self.store.set_user_locked_status(user_id, False)
+
             if "user_type" in body:
                 await self.store.set_user_type(target_user, user_type)
 
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 94ad90942f39..2e104d488889 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -40,7 +40,9 @@ def __init__(self, hs: "HomeServer"):
         self._device_handler = handler
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_expired=True)
+        requester = await self.auth.get_user_by_req(
+            request, allow_expired=True, allow_locked=True
+        )
 
         if requester.device_id is None:
             # The access token wasn't associated with a device.
@@ -67,7 +69,9 @@ def __init__(self, hs: "HomeServer"):
         self._device_handler = handler
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_expired=True)
+        requester = await self.auth.get_user_by_req(
+            request, allow_expired=True, allow_locked=True
+        )
         user_id = requester.user.to_string()
 
         # first delete all of the user's devices
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c582cf05732d..d3a01d526fb8 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -205,7 +205,8 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
                     name, password_hash, is_guest, admin, consent_version, consent_ts,
                     consent_server_notice_sent, appservice_id, creation_ts, user_type,
                     deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
-                    COALESCE(approved, TRUE) AS approved
+                    COALESCE(approved, TRUE) AS approved,
+                    COALESCE(locked, FALSE) AS locked
                 FROM users
                 WHERE name = ?
                 """,
@@ -230,10 +231,15 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
             # want to make sure we're returning the right type of data.
             # Note: when adding a column name to this list, be wary of NULLable columns,
             # since NULL values will be turned into False.
-            boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
+            boolean_columns = [
+                "admin",
+                "deactivated",
+                "shadow_banned",
+                "approved",
+                "locked",
+            ]
             for column in boolean_columns:
-                if not isinstance(row[column], bool):
-                    row[column] = bool(row[column])
+                row[column] = bool(row[column])
 
         return row
 
@@ -1116,6 +1122,27 @@ async def get_user_deactivated_status(self, user_id: str) -> bool:
         # Convert the integer into a boolean.
         return res == 1
 
+    @cached()
+    async def get_user_locked_status(self, user_id: str) -> bool:
+        """Retrieve the value for the `locked` property for the provided user.
+
+        Args:
+            user_id: The ID of the user to retrieve the status for.
+
+        Returns:
+            True if the user was locked, false if the user is still active.
+        """
+
+        res = await self.db_pool.simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="locked",
+            desc="get_user_locked_status",
+        )
+
+        # Convert the potential integer into a boolean.
+        return bool(res)
+
     async def get_threepid_validation_session(
         self,
         medium: Optional[str],
@@ -2111,6 +2138,33 @@ def set_user_deactivated_status_txn(
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
+    async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
+        """Set the `locked` property for the provided user to the provided value.
+
+        Args:
+            user_id: The ID of the user to set the status for.
+            locked: The value to set for `locked`.
+        """
+
+        await self.db_pool.runInteraction(
+            "set_user_locked_status",
+            self.set_user_locked_status_txn,
+            user_id,
+            locked,
+        )
+
+    def set_user_locked_status_txn(
+        self, txn: LoggingTransaction, user_id: str, locked: bool
+    ) -> None:
+        self.db_pool.simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"locked": locked},
+        )
+        self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,))
+        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
     def update_user_approval_status_txn(
         self, txn: LoggingTransaction, user_id: str, approved: bool
     ) -> None:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 2a136f2ff6e8..f0dc31fee649 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -995,7 +995,11 @@ async def get_user_directory_stream_pos(self) -> Optional[int]:
         )
 
     async def search_user_dir(
-        self, user_id: str, search_term: str, limit: int
+        self,
+        user_id: str,
+        search_term: str,
+        limit: int,
+        show_locked_users: bool = False,
     ) -> SearchResult:
         """Searches for users in directory
 
@@ -1029,6 +1033,9 @@ async def search_user_dir(
                 )
             """
 
+        if not show_locked_users:
+            where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)"
+
         # We allow manipulating the ranking algorithm by injecting statements
         # based on config options.
         additional_ordering_statements = []
@@ -1060,6 +1067,7 @@ async def search_user_dir(
                 SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM matching_users as t
                 INNER JOIN user_directory AS d USING (user_id)
+                LEFT JOIN users AS u ON t.user_id = u.name
                 WHERE
                     %(where_clause)s
                 ORDER BY
@@ -1115,6 +1123,7 @@ async def search_user_dir(
                 SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
+                LEFT JOIN users AS u ON t.user_id = u.name
                 WHERE
                     %(where_clause)s
                     AND value MATCH ?
diff --git a/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql
new file mode 100644
index 000000000000..21c79714412c
--- /dev/null
+++ b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql
@@ -0,0 +1,16 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE users ADD locked BOOLEAN DEFAULT FALSE NOT NULL;
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cdb0048122a1..ce96574915fd 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -69,6 +69,7 @@ def test_get_user_by_req_user_valid_token(self) -> None:
         )
         self.store.get_user_by_access_token = simple_async_mock(user_info)
         self.store.mark_access_token_as_used = simple_async_mock(None)
+        self.store.get_user_locked_status = simple_async_mock(False)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -293,6 +294,7 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non
         )
         self.store.insert_client_ip = simple_async_mock(None)
         self.store.mark_access_token_as_used = simple_async_mock(None)
+        self.store.get_user_locked_status = simple_async_mock(False)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -311,6 +313,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
                 token_used=True,
             )
         )
+        self.store.get_user_locked_status = simple_async_mock(False)
         self.store.insert_client_ip = simple_async_mock(None)
         self.store.mark_access_token_as_used = simple_async_mock(None)
         request = Mock(args={})
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9af9db6e3eb7..41a959b4d6c7 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -29,7 +29,16 @@
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
 from synapse.media.filepath import MediaFilePaths
-from synapse.rest.client import devices, login, logout, profile, register, room, sync
+from synapse.rest.client import (
+    devices,
+    login,
+    logout,
+    profile,
+    register,
+    room,
+    sync,
+    user_directory,
+)
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
@@ -1477,6 +1486,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
         sync.register_servlets,
         register.register_servlets,
+        user_directory.register_servlets,
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -2464,6 +2474,105 @@ def test_deactivate_user(self) -> None:
         # This key was removed intentionally. Ensure it is not accidentally re-included.
         self.assertNotIn("password_hash", channel.json_body)
 
+    def test_locked_user(self) -> None:
+        # User can sync
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/sync",
+            access_token=self.other_user_token,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # Lock user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"locked": True},
+        )
+
+        # User is not authorized to sync anymore
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/sync",
+            access_token=self.other_user_token,
+        )
+        self.assertEqual(401, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.USER_LOCKED, channel.json_body["errcode"])
+        self.assertTrue(channel.json_body["soft_logout"])
+
+    @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+    def test_locked_user_not_in_user_dir(self) -> None:
+        # User is available in the user dir
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/user_directory/search",
+            {"search_term": self.other_user},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("results", channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
+        # Lock user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"locked": True},
+        )
+
+        # User is not available anymore in the user dir
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/user_directory/search",
+            {"search_term": self.other_user},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("results", channel.json_body)
+        self.assertEqual(0, len(channel.json_body["results"]))
+
+    @override_config(
+        {
+            "user_directory": {
+                "enabled": True,
+                "search_all_users": True,
+                "show_locked_users": True,
+            }
+        }
+    )
+    def test_locked_user_in_user_dir_with_show_locked_users_option(self) -> None:
+        # User is available in the user dir
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/user_directory/search",
+            {"search_term": self.other_user},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("results", channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
+        # Lock user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"locked": True},
+        )
+
+        # User is still available in the user dir
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/user_directory/search",
+            {"search_term": self.other_user},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("results", channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
     @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
     def test_change_name_deactivate_user_user_directory(self) -> None:
         """
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 05ea802008ad..ba41459d083d 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -48,6 +48,7 @@ def test_register(self) -> None:
                 "creation_ts": 0,
                 "user_type": None,
                 "deactivated": 0,
+                "locked": 0,
                 "shadow_banned": 0,
                 "approved": 1,
             },

From efd4d06d7694e269f1d85e697104e742a984da18 Mon Sep 17 00:00:00 2001
From: Patrick Cloke <clokep@users.noreply.github.com>
Date: Thu, 10 Aug 2023 07:39:46 -0400
Subject: [PATCH 2/4] Clean-up presence code (#16092)

Misc. clean-ups to:

* Use keyword arguments.
* Return early (reducing indentation) of some functions.
* Removing duplicated / unused code.
* Use wrap_as_background_process.
---
 changelog.d/16092.misc       |   1 +
 synapse/handlers/presence.py | 169 ++++++++++++++++-------------------
 2 files changed, 76 insertions(+), 94 deletions(-)
 create mode 100644 changelog.d/16092.misc

diff --git a/changelog.d/16092.misc b/changelog.d/16092.misc
new file mode 100644
index 000000000000..b52080777105
--- /dev/null
+++ b/changelog.d/16092.misc
@@ -0,0 +1 @@
+Clean-up the presence code.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index cd7df0525f4f..11dff724e665 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -30,7 +30,6 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    Awaitable,
     Callable,
     Collection,
     Dict,
@@ -54,7 +53,10 @@
 from synapse.events.presence_router import PresenceRouter
 from synapse.logging.context import run_in_background
 from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.replication.http.presence import (
     ReplicationBumpPresenceActiveTime,
     ReplicationPresenceSetState,
@@ -141,6 +143,8 @@ def __init__(self, hs: "HomeServer"):
         self.state = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
 
+        self._presence_enabled = hs.config.server.use_presence
+
         self._federation = None
         if hs.should_send_federation():
             self._federation = hs.get_federation_sender()
@@ -149,6 +153,15 @@ def __init__(self, hs: "HomeServer"):
 
         self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
 
+        self.VALID_PRESENCE: Tuple[str, ...] = (
+            PresenceState.ONLINE,
+            PresenceState.UNAVAILABLE,
+            PresenceState.OFFLINE,
+        )
+
+        if self._busy_presence_enabled:
+            self.VALID_PRESENCE += (PresenceState.BUSY,)
+
         active_presence = self.store.take_presence_startup_info()
         self.user_to_current_state = {state.user_id: state for state in active_presence}
 
@@ -395,8 +408,6 @@ def __init__(self, hs: "HomeServer"):
 
         self._presence_writer_instance = hs.config.worker.writers.presence[0]
 
-        self._presence_enabled = hs.config.server.use_presence
-
         # Route presence EDUs to the right worker
         hs.get_federation_registry().register_instances_for_edu(
             EduTypes.PRESENCE,
@@ -421,8 +432,6 @@ def __init__(self, hs: "HomeServer"):
             self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
         )
 
-        self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
-
         hs.get_reactor().addSystemEventTrigger(
             "before",
             "shutdown",
@@ -490,7 +499,9 @@ async def user_syncing(
             # what the spec wants: see comment in the BasePresenceHandler version
             # of this function.
             await self.set_state(
-                UserID.from_string(user_id), {"presence": presence_state}, True
+                UserID.from_string(user_id),
+                {"presence": presence_state},
+                ignore_status_msg=True,
             )
 
         curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
@@ -601,22 +612,13 @@ async def set_state(
         """
         presence = state["presence"]
 
-        valid_presence = (
-            PresenceState.ONLINE,
-            PresenceState.UNAVAILABLE,
-            PresenceState.OFFLINE,
-            PresenceState.BUSY,
-        )
-
-        if presence not in valid_presence or (
-            presence == PresenceState.BUSY and not self._busy_presence_enabled
-        ):
+        if presence not in self.VALID_PRESENCE:
             raise SynapseError(400, "Invalid presence state")
 
         user_id = target_user.to_string()
 
         # If presence is disabled, no-op
-        if not self.hs.config.server.use_presence:
+        if not self._presence_enabled:
             return
 
         # Proxy request to instance that writes presence
@@ -633,7 +635,7 @@ async def bump_presence_active_time(self, user: UserID) -> None:
         with the app.
         """
         # If presence is disabled, no-op
-        if not self.hs.config.server.use_presence:
+        if not self._presence_enabled:
             return
 
         # Proxy request to instance that writes presence
@@ -649,7 +651,6 @@ def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.wheel_timer: WheelTimer[str] = WheelTimer()
         self.notifier = hs.get_notifier()
-        self._presence_enabled = hs.config.server.use_presence
 
         federation_registry = hs.get_federation_registry()
 
@@ -700,8 +701,6 @@ def __init__(self, hs: "HomeServer"):
             self._on_shutdown,
         )
 
-        self._next_serial = 1
-
         # Keeps track of the number of *ongoing* syncs on this process. While
         # this is non zero a user will never go offline.
         self.user_to_num_current_syncs: Dict[str, int] = {}
@@ -723,21 +722,16 @@ def __init__(self, hs: "HomeServer"):
             # Start a LoopingCall in 30s that fires every 5s.
             # The initial delay is to allow disconnected clients a chance to
             # reconnect before we treat them as offline.
-            def run_timeout_handler() -> Awaitable[None]:
-                return run_as_background_process(
-                    "handle_presence_timeouts", self._handle_timeouts
-                )
-
             self.clock.call_later(
-                30, self.clock.looping_call, run_timeout_handler, 5000
+                30, self.clock.looping_call, self._handle_timeouts, 5000
             )
 
-            def run_persister() -> Awaitable[None]:
-                return run_as_background_process(
-                    "persist_presence_changes", self._persist_unpersisted_changes
-                )
-
-            self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
+            self.clock.call_later(
+                60,
+                self.clock.looping_call,
+                self._persist_unpersisted_changes,
+                60 * 1000,
+            )
 
         LaterGauge(
             "synapse_handlers_presence_wheel_timer_size",
@@ -783,6 +777,7 @@ async def _on_shutdown(self) -> None:
             )
         logger.info("Finished _on_shutdown")
 
+    @wrap_as_background_process("persist_presence_changes")
     async def _persist_unpersisted_changes(self) -> None:
         """We periodically persist the unpersisted changes, as otherwise they
         may stack up and slow down shutdown times.
@@ -898,6 +893,7 @@ async def _update_states(
                         states, [destination]
                     )
 
+    @wrap_as_background_process("handle_presence_timeouts")
     async def _handle_timeouts(self) -> None:
         """Checks the presence of users that have timed out and updates as
         appropriate.
@@ -955,7 +951,7 @@ async def bump_presence_active_time(self, user: UserID) -> None:
         with the app.
         """
         # If presence is disabled, no-op
-        if not self.hs.config.server.use_presence:
+        if not self._presence_enabled:
             return
 
         user_id = user.to_string()
@@ -990,56 +986,51 @@ async def user_syncing(
                 client that is being used by a user.
             presence_state: The presence state indicated in the sync request
         """
-        # Override if it should affect the user's presence, if presence is
-        # disabled.
-        if not self.hs.config.server.use_presence:
-            affect_presence = False
+        if not affect_presence or not self._presence_enabled:
+            return _NullContextManager()
 
-        if affect_presence:
-            curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
-            self.user_to_num_current_syncs[user_id] = curr_sync + 1
+        curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
+        self.user_to_num_current_syncs[user_id] = curr_sync + 1
 
-            prev_state = await self.current_state_for_user(user_id)
+        prev_state = await self.current_state_for_user(user_id)
 
-            # If they're busy then they don't stop being busy just by syncing,
-            # so just update the last sync time.
-            if prev_state.state != PresenceState.BUSY:
-                # XXX: We set_state separately here and just update the last_active_ts above
-                # This keeps the logic as similar as possible between the worker and single
-                # process modes. Using set_state will actually cause last_active_ts to be
-                # updated always, which is not what the spec calls for, but synapse has done
-                # this for... forever, I think.
-                await self.set_state(
-                    UserID.from_string(user_id), {"presence": presence_state}, True
-                )
-                # Retrieve the new state for the logic below. This should come from the
-                # in-memory cache.
-                prev_state = await self.current_state_for_user(user_id)
+        # If they're busy then they don't stop being busy just by syncing,
+        # so just update the last sync time.
+        if prev_state.state != PresenceState.BUSY:
+            # XXX: We set_state separately here and just update the last_active_ts above
+            # This keeps the logic as similar as possible between the worker and single
+            # process modes. Using set_state will actually cause last_active_ts to be
+            # updated always, which is not what the spec calls for, but synapse has done
+            # this for... forever, I think.
+            await self.set_state(
+                UserID.from_string(user_id),
+                {"presence": presence_state},
+                ignore_status_msg=True,
+            )
+            # Retrieve the new state for the logic below. This should come from the
+            # in-memory cache.
+            prev_state = await self.current_state_for_user(user_id)
 
-            # To keep the single process behaviour consistent with worker mode, run the
-            # same logic as `update_external_syncs_row`, even though it looks weird.
-            if prev_state.state == PresenceState.OFFLINE:
-                await self._update_states(
-                    [
-                        prev_state.copy_and_replace(
-                            state=PresenceState.ONLINE,
-                            last_active_ts=self.clock.time_msec(),
-                            last_user_sync_ts=self.clock.time_msec(),
-                        )
-                    ]
-                )
-            # otherwise, set the new presence state & update the last sync time,
-            # but don't update last_active_ts as this isn't an indication that
-            # they've been active (even though it's probably been updated by
-            # set_state above)
-            else:
-                await self._update_states(
-                    [
-                        prev_state.copy_and_replace(
-                            last_user_sync_ts=self.clock.time_msec()
-                        )
-                    ]
-                )
+        # To keep the single process behaviour consistent with worker mode, run the
+        # same logic as `update_external_syncs_row`, even though it looks weird.
+        if prev_state.state == PresenceState.OFFLINE:
+            await self._update_states(
+                [
+                    prev_state.copy_and_replace(
+                        state=PresenceState.ONLINE,
+                        last_active_ts=self.clock.time_msec(),
+                        last_user_sync_ts=self.clock.time_msec(),
+                    )
+                ]
+            )
+        # otherwise, set the new presence state & update the last sync time,
+        # but don't update last_active_ts as this isn't an indication that
+        # they've been active (even though it's probably been updated by
+        # set_state above)
+        else:
+            await self._update_states(
+                [prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())]
+            )
 
         async def _end() -> None:
             try:
@@ -1061,8 +1052,7 @@ def _user_syncing() -> Generator[None, None, None]:
             try:
                 yield
             finally:
-                if affect_presence:
-                    run_in_background(_end)
+                run_in_background(_end)
 
         return _user_syncing()
 
@@ -1229,20 +1219,11 @@ async def set_state(
         status_msg = state.get("status_msg", None)
         presence = state["presence"]
 
-        valid_presence = (
-            PresenceState.ONLINE,
-            PresenceState.UNAVAILABLE,
-            PresenceState.OFFLINE,
-            PresenceState.BUSY,
-        )
-
-        if presence not in valid_presence or (
-            presence == PresenceState.BUSY and not self._busy_presence_enabled
-        ):
+        if presence not in self.VALID_PRESENCE:
             raise SynapseError(400, "Invalid presence state")
 
         # If presence is disabled, no-op
-        if not self.hs.config.server.use_presence:
+        if not self._presence_enabled:
             return
 
         user_id = target_user.to_string()

From 7f4b41369049c143919d229670087df69edb9602 Mon Sep 17 00:00:00 2001
From: reivilibre <oliverw@matrix.org>
Date: Thu, 10 Aug 2023 17:28:31 +0000
Subject: [PATCH 3/4] Fix the type annotation on `run_db_interaction` in the
 Module API. (#16089)

* Fix the method signature of `run_db_interaction` on the module API

* Newsfile

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>

---------

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>
---
 changelog.d/16089.misc         | 1 +
 synapse/module_api/__init__.py | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/16089.misc

diff --git a/changelog.d/16089.misc b/changelog.d/16089.misc
new file mode 100644
index 000000000000..8c302e6884d1
--- /dev/null
+++ b/changelog.d/16089.misc
@@ -0,0 +1 @@
+Fix the type annotation on `run_db_interaction` in the Module API.
\ No newline at end of file
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index acee1dafd3ae..9ad8e038aee2 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -31,7 +31,7 @@
 
 import attr
 import jinja2
-from typing_extensions import ParamSpec
+from typing_extensions import Concatenate, ParamSpec
 
 from twisted.internet import defer
 from twisted.internet.interfaces import IDelayedCall
@@ -885,7 +885,7 @@ def invalidate_access_token(
     def run_db_interaction(
         self,
         desc: str,
-        func: Callable[P, T],
+        func: Callable[Concatenate[LoggingTransaction, P], T],
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> "defer.Deferred[T]":

From 614efc488b1a25dfa32256930c5acc896c88d92f Mon Sep 17 00:00:00 2001
From: Nick Mills-Barrett <nick@beeper.com>
Date: Fri, 11 Aug 2023 12:37:09 +0100
Subject: [PATCH 4/4] Add linearizer on user ID to push rule PUT/DELETE
 requests (#16052)

See: #16053

Signed off by Nick @ Beeper (@Fizzadar)
---
 changelog.d/16052.bugfix         |  1 +
 synapse/rest/client/push_rule.py | 28 ++++++++++++++++++++++------
 2 files changed, 23 insertions(+), 6 deletions(-)
 create mode 100644 changelog.d/16052.bugfix

diff --git a/changelog.d/16052.bugfix b/changelog.d/16052.bugfix
new file mode 100644
index 000000000000..3c7a60f226a7
--- /dev/null
+++ b/changelog.d/16052.bugfix
@@ -0,0 +1 @@
+Fix long-standing bug where concurrent requests to change a user's push rules could cause a deadlock. Contributed by Nick @ Beeper (@fizzadar).
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 5c9fece3ba33..5ed3b83a03e2 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -32,6 +32,7 @@
 from synapse.rest.client._base import client_patterns
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.types import JsonDict
+from synapse.util.async_helpers import Linearizer
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -53,26 +54,32 @@ def __init__(self, hs: "HomeServer"):
         self.notifier = hs.get_notifier()
         self._is_worker = hs.config.worker.worker_app is not None
         self._push_rules_handler = hs.get_push_rules_handler()
+        self._push_rule_linearizer = Linearizer(name="push_rules")
 
     async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
         if self._is_worker:
             raise Exception("Cannot handle PUT /push_rules on worker")
 
+        requester = await self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
+
+        async with self._push_rule_linearizer.queue(user_id):
+            return await self.handle_put(request, path, user_id)
+
+    async def handle_put(
+        self, request: SynapseRequest, path: str, user_id: str
+    ) -> Tuple[int, JsonDict]:
         spec = _rule_spec_from_path(path.split("/"))
         try:
             priority_class = _priority_class_from_spec(spec)
         except InvalidRuleException as e:
             raise SynapseError(400, str(e))
 
-        requester = await self.auth.get_user_by_req(request)
-
         if "/" in spec.rule_id or "\\" in spec.rule_id:
             raise SynapseError(400, "rule_id may not contain slashes")
 
         content = parse_json_value_from_request(request)
 
-        user_id = requester.user.to_string()
-
         if spec.attr:
             try:
                 await self._push_rules_handler.set_rule_attr(user_id, spec, content)
@@ -126,11 +133,20 @@ async def on_DELETE(
         if self._is_worker:
             raise Exception("Cannot handle DELETE /push_rules on worker")
 
-        spec = _rule_spec_from_path(path.split("/"))
-
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
 
+        async with self._push_rule_linearizer.queue(user_id):
+            return await self.handle_delete(request, path, user_id)
+
+    async def handle_delete(
+        self,
+        request: SynapseRequest,
+        path: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        spec = _rule_spec_from_path(path.split("/"))
+
         namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
 
         try: