Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to tests files #12256

Merged
merged 3 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12256.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to tests files.
2 changes: 0 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ exclude = (?x)
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
|tests/storage/test_background_update.py
|tests/storage/test_base.py
|tests/storage/test_id_generators.py
|tests/storage/test_roommember.py
|tests/test_metrics.py
|tests/test_phone_home.py
Expand Down
35 changes: 20 additions & 15 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
from unittest.mock import ANY, Mock, call

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable
Expand All @@ -42,7 +45,9 @@
OTHER_ROOM_ID = "another-room"


def _expect_edu_transaction(edu_type, content, origin="test"):
def _expect_edu_transaction(
edu_type: str, content: JsonDict, origin: str = "test"
) -> JsonDict:
return {
"origin": origin,
"origin_server_ts": 1000000,
Expand All @@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
}


def _make_edu_transaction_json(edu_type, content):
def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")


class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
Expand All @@ -83,7 +88,7 @@ def create_resource_dict(self) -> Dict[str, Resource]:
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event

Expand Down Expand Up @@ -111,24 +116,24 @@ def get_received_txn_response(*args):

self.room_members = []

async def check_user_in_room(room_id, user_id):
async def check_user_in_room(room_id: str, user_id: str) -> None:
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
return None

hs.get_auth().check_user_in_room = check_user_in_room

async def check_host_in_room(room_id, server_name):
async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID

hs.get_event_auth_handler().check_host_in_room = check_host_in_room

def get_joined_hosts_for_room(room_id):
def get_joined_hosts_for_room(room_id: str):
return {member.domain for member in self.room_members}

self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room

async def get_users_in_room(room_id):
async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}

self.datastore.get_users_in_room = get_users_in_room
Expand All @@ -153,7 +158,7 @@ async def get_users_in_room(room_id):
lambda *args, **kwargs: make_awaitable(None)
)

def test_started_typing_local(self):
def test_started_typing_local(self) -> None:
self.room_members = [U_APPLE, U_BANANA]

self.assertEqual(self.event_source.get_current_key(), 0)
Expand Down Expand Up @@ -187,7 +192,7 @@ def test_started_typing_local(self):
)

@override_config({"send_federation": True})
def test_started_typing_remote_send(self):
def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]

self.get_success(
Expand Down Expand Up @@ -217,7 +222,7 @@ def test_started_typing_remote_send(self):
try_trailing_slash_on_400=True,
)

def test_started_typing_remote_recv(self):
def test_started_typing_remote_recv(self) -> None:
self.room_members = [U_APPLE, U_ONION]

self.assertEqual(self.event_source.get_current_key(), 0)
Expand Down Expand Up @@ -256,7 +261,7 @@ def test_started_typing_remote_recv(self):
],
)

def test_started_typing_remote_recv_not_in_room(self):
def test_started_typing_remote_recv_not_in_room(self) -> None:
self.room_members = [U_APPLE, U_ONION]

self.assertEqual(self.event_source.get_current_key(), 0)
Expand Down Expand Up @@ -292,7 +297,7 @@ def test_started_typing_remote_recv_not_in_room(self):
self.assertEqual(events[1], 0)

@override_config({"send_federation": True})
def test_stopped_typing(self):
def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]

# Gut-wrenching
Expand Down Expand Up @@ -343,7 +348,7 @@ def test_stopped_typing(self):
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
)

def test_typing_timeout(self):
def test_typing_timeout(self) -> None:
self.room_members = [U_APPLE, U_BANANA]

self.assertEqual(self.event_source.get_current_key(), 0)
Expand Down
23 changes: 12 additions & 11 deletions tests/push/test_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict
from typing import Dict, Optional, Union

import frozendict

from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.types import JsonDict

from tests import unittest


class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(self, content):
def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
event = FrozenEvent(
{
"event_id": "$event_id",
Expand All @@ -39,12 +40,12 @@ def _get_evaluator(self, content):
)
room_member_count = 0
sender_power_level = 0
power_levels = {}
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
)

def test_display_name(self):
def test_display_name(self) -> None:
"""Check for a matching display name in the body of the event."""
evaluator = self._get_evaluator({"body": "foo bar baz"})

Expand All @@ -71,20 +72,20 @@ def test_display_name(self):
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))

def _assert_matches(
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)

def _assert_not_matches(
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertFalse(
evaluator.matches(condition, "@user:test", "display_name"), msg
)

def test_event_match_body(self):
def test_event_match_body(self) -> None:
"""Check that event_match conditions on content.body work as expected"""

# if the key is `content.body`, the pattern matches substrings.
Expand Down Expand Up @@ -165,7 +166,7 @@ def test_event_match_body(self):
r"? after \ should match any character",
)

def test_event_match_non_body(self):
def test_event_match_non_body(self) -> None:
"""Check that event_match conditions on other keys work as expected"""

# if the key is anything other than 'content.body', the pattern must match the
Expand Down Expand Up @@ -241,7 +242,7 @@ def test_event_match_non_body(self):
"pattern should not match before a newline",
)

def test_no_body(self):
def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})

Expand All @@ -250,7 +251,7 @@ def test_no_body(self):
}
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))

def test_invalid_body(self):
def test_invalid_body(self) -> None:
"""A non-string body should not break the evaluator."""
condition = {
"kind": "contains_display_name",
Expand All @@ -260,7 +261,7 @@ def test_invalid_body(self):
evaluator = self._get_evaluator({"body": body})
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))

def test_tweaks_for_actions(self):
def test_tweaks_for_actions(self) -> None:
"""
This tests the behaviour of tweaks_for_actions.
"""
Expand Down
Loading