From ebff1473474bd20dfa7aefbd3f8a79d7c51c6272 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Thu, 4 Apr 2024 13:14:50 -0700 Subject: [PATCH] add functionality to get/set suspended status --- .../storage/databases/main/registration.py | 54 ++++++++++++++++++- tests/storage/test_registration.py | 2 +- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index d939ade4271..1dc23c392ab 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -236,7 +236,8 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]: consent_server_notice_sent, appservice_id, creation_ts, user_type, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, COALESCE(approved, TRUE) AS approved, - COALESCE(locked, FALSE) AS locked + COALESCE(locked, FALSE) AS locked, + COALESCE(suspended, FALSE) AS suspended FROM users WHERE name = ? """, @@ -261,6 +262,7 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]: shadow_banned, approved, locked, + suspended, ) = row return UserInfo( @@ -277,6 +279,7 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]: user_type=user_type, approved=bool(approved), locked=bool(locked), + suspended=bool(suspended), ) return await self.db_pool.runInteraction( @@ -1180,6 +1183,26 @@ async def get_user_locked_status(self, user_id: str) -> bool: # Convert the potential integer into a boolean. return bool(res) + @cached() + async def get_user_suspended_status(self, user_id: str) -> bool: + """ + Determine whether the user's account is suspended. + Args: + user_id: The user ID of the user in question + Returns: + True if the user's account is suspended, false if not. + """ + + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="suspended", + allow_none=True, + desc="get_user_suspended", + ) + + return bool(res) + async def get_threepid_validation_session( self, medium: Optional[str], @@ -2206,6 +2229,35 @@ def set_user_deactivated_status_txn( self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None: + """ + Set whether the user's account is suspended in the `users` table. + + Args: + user_id: The user ID of the user in question + suspended: True if the user is suspended, false if not + """ + await self.db_pool.runInteraction( + "set_user_suspended_status", + self.set_user_suspended_status_txn, + user_id, + suspended, + ) + + def set_user_suspended_status_txn( + self, txn: LoggingTransaction, user_id: str, suspended: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"suspended": 1 if suspended else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_suspended_status, (user_id,) + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + async def set_user_locked_status(self, user_id: str, locked: bool) -> None: """Set the `locked` property for the provided user to the provided value. diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 505465d529e..14e3871dc1c 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -43,7 +43,6 @@ def test_register(self) -> None: self.assertEqual( UserInfo( - # TODO(paul): Surely this field should be 'user_id', not 'name' user_id=UserID.from_string(self.user_id), is_admin=False, is_guest=False, @@ -57,6 +56,7 @@ def test_register(self) -> None: locked=False, is_shadow_banned=False, approved=True, + suspended=False, ), (self.get_success(self.store.get_user_by_id(self.user_id))), )