From f8f6eea3347b6bf3b4e5710fd133ae9a4ed71c93 Mon Sep 17 00:00:00 2001 From: David Teller Date: Wed, 11 May 2022 12:06:08 +0200 Subject: [PATCH] WIP: Ported existing tests --- tests/handlers/test_user_directory.py | 32 +++++- tests/rest/client/test_rooms.py | 148 ++++++++++++++++++++++++-- 2 files changed, 171 insertions(+), 9 deletions(-) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 96e2e3039ba8..93dd2b6880ae 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -16,6 +16,7 @@ from urllib.parse import quote from twisted.test.proto_helpers import MemoryReactor +from synapse.api.errors import Code import synapse.rest.admin from synapse.api.constants import UserTypes @@ -23,6 +24,7 @@ from synapse.appservice import ApplicationService from synapse.rest.client import login, register, room, user_directory from synapse.server import HomeServer +from synapse.spam_checker_api import ALLOW, Decision from synapse.storage.roommember import ProfileInfo from synapse.types import create_requester from synapse.util import Clock @@ -773,12 +775,24 @@ def test_spam_checker(self) -> None: s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) - async def allow_all(user_profile: ProfileInfo) -> bool: + async def allow_all_old(user_profile: ProfileInfo) -> bool: # Allow all users. return False - # Configure a spam checker that does not filter any users. + # Configure a spam checker that does not filter any users (old-style) spam_checker = self.hs.get_spam_checker() + spam_checker._check_username_for_spam_callbacks = [allow_all_old] + + # The results do not change: + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + async def allow_all(user_profile: ProfileInfo) -> Decision: + # Allow all users. + return ALLOW + + # Configure a spam checker that does not filter any users spam_checker._check_username_for_spam_callbacks = [allow_all] # The results do not change: @@ -787,16 +801,28 @@ async def allow_all(user_profile: ProfileInfo) -> bool: self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - async def block_all(user_profile: ProfileInfo) -> bool: + async def block_all_old(user_profile: ProfileInfo) -> bool: # All users are spammy. return True + spam_checker._check_username_for_spam_callbacks = [block_all_old] + + # User1 now gets no search results for any of the other users. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + # Configure a spam checker that filters all users. + async def block_all(user_profile: ProfileInfo) -> bool: + # All users are spammy. + return Code.FORBIDDEN + spam_checker._check_username_for_spam_callbacks = [block_all] # User1 now gets no search results for any of the other users. s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) + def test_legacy_spam_checker(self) -> None: """ A spam checker without the expected method should be ignored. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 9443daa0560a..24bfde2e37f4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -31,11 +31,12 @@ Membership, RelationTypes, ) -from synapse.api.errors import Codes, HttpResponseException +from synapse.api.errors import Code, Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync from synapse.server import HomeServer +from synapse.spam_checker_api import ALLOW, Decision from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util import Clock from synapse.util.stringutils import random_string @@ -676,9 +677,9 @@ def test_post_room_invitees_ratelimit(self) -> None: channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_old(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed - when creating a new room. + when creating a new room (old-style API, returning a boolean). """ async def user_may_join_room( @@ -700,6 +701,29 @@ async def user_may_join_room( self.assertEqual(join_mock.call_count, 0) + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly bypassed + when creating a new room. + """ + + async def user_may_join_room( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Decision: + return Code.FORBIDDEN + + join_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + self.assertEqual(join_mock.call_count, 0) class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -910,9 +934,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_old(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called - and blocks room joins when needed. + and blocks room joins when needed (old-style API, return a boolean). """ # Register a dummy callback. Make it allow all room joins for now. @@ -967,6 +991,63 @@ async def user_may_join_room( return_value = False self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly called + and blocks room joins when needed. + """ + + # Register a dummy callback. Make it allow all room joins for now. + return_value = ALLOW + + async def user_may_join_room( + userid: str, + room_id: str, + is_invited: bool, + ) -> Decision: + return return_value + + callback_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) + + # Join a first room, without being invited to it. + self.helper.join(self.room1, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room1, + False, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Join a second room, this time with an invite for it. + self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) + self.helper.join(self.room2, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room2, + True, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + return_value = Code.FORBIDDEN + self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" @@ -2586,7 +2667,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_threepid_invite_spamcheck(self) -> None: + def test_threepid_invite_spamcheck_old(self) -> None: # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we # can check its call_count later on during the test. @@ -2640,3 +2721,58 @@ def test_threepid_invite_spamcheck(self) -> None: # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() + + def test_threepid_invite_spamcheck(self) -> None: + # Mock a few functions to prevent the test from failing due to failing to talk to + # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we + # can check its call_count later on during the test. + make_invite_mock = Mock(return_value=make_awaitable(0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock + self.hs.get_identity_handler().lookup_3pid = Mock( + return_value=make_awaitable(None), + ) + + # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it + # allow everything for now. + mock = Mock(return_value=make_awaitable(ALLOW)) + self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) + + # Send a 3PID invite into the room and check that it succeeded. + email_to_invite = "teresa@example.com" + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Check that the callback was called with the right params. + mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) + + # Check that the call to send the invite was made. + make_invite_mock.assert_called_once() + + # Now change the return value of the callback to deny any invite and test that + # we can't send the invite. + mock.return_value = make_awaitable(Code.FORBIDDEN) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once()