Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints to more tests files. (#12240)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel authored Mar 17, 2022
1 parent 3f7cfbc commit 9e06e22
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 47 deletions.
1 change: 1 addition & 0 deletions changelog.d/12240.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to tests files.
4 changes: 0 additions & 4 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ exclude = (?x)
|tests/federation/test_federation_server.py
|tests/federation/transport/test_knocking.py
|tests/federation/transport/test_server.py
|tests/handlers/test_cas.py
|tests/handlers/test_federation.py
|tests/handlers/test_presence.py
|tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
Expand All @@ -80,7 +77,6 @@ exclude = (?x)
|tests/logging/test_terse_json.py
|tests/module_api/test_api.py
|tests/push/test_email.py
|tests/push/test_http.py
|tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py
Expand Down
19 changes: 12 additions & 7 deletions tests/handlers/test_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock

from twisted.test.proto_helpers import MemoryReactor

from synapse.handlers.cas import CasResponse
from synapse.server import HomeServer
from synapse.util import Clock

from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
Expand All @@ -24,7 +29,7 @@


class CasHandlerTestCase(HomeserverTestCase):
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
cas_config = {
Expand All @@ -40,7 +45,7 @@ def default_config(self):

return config

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()

self.handler = hs.get_cas_handler()
Expand All @@ -51,7 +56,7 @@ def make_homeserver(self, reactor, clock):

return hs

def test_map_cas_user_to_user(self):
def test_map_cas_user_to_user(self) -> None:
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""

# stub out the auth handler
Expand All @@ -75,7 +80,7 @@ def test_map_cas_user_to_user(self):
auth_provider_session_id=None,
)

def test_map_cas_user_to_existing_user(self):
def test_map_cas_user_to_existing_user(self) -> None:
"""Existing users can log in with CAS account."""
store = self.hs.get_datastores().main
self.get_success(
Expand Down Expand Up @@ -119,7 +124,7 @@ def test_map_cas_user_to_existing_user(self):
auth_provider_session_id=None,
)

def test_map_cas_user_to_invalid_localpart(self):
def test_map_cas_user_to_invalid_localpart(self) -> None:
"""CAS automaps invalid characters to base-64 encoding."""

# stub out the auth handler
Expand Down Expand Up @@ -150,7 +155,7 @@ def test_map_cas_user_to_invalid_localpart(self):
}
}
)
def test_required_attributes(self):
def test_required_attributes(self) -> None:
"""The required attributes must be met from the CAS response."""

# stub out the auth handler
Expand All @@ -166,7 +171,7 @@ def test_required_attributes(self):
auth_handler.complete_sso_login.assert_not_called()

# The response doesn't have any department.
cas_response = CasResponse("test_user", {"userGroup": "staff"})
cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
Expand Down
36 changes: 21 additions & 15 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List
from typing import List, cast
from unittest import TestCase

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
Expand All @@ -23,7 +25,9 @@
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string

from tests import unittest
Expand All @@ -42,15 +46,15 @@ class FederationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
self.state_store = hs.get_storage().state
self._event_auth_handler = hs.get_event_auth_handler()
return hs

def test_exchange_revoked_invite(self):
def test_exchange_revoked_invite(self) -> None:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")

Expand Down Expand Up @@ -96,7 +100,7 @@ def test_exchange_revoked_invite(self):
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")

def test_rejected_message_event_state(self):
def test_rejected_message_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected non-state events.
Expand Down Expand Up @@ -126,7 +130,7 @@ def test_rejected_message_event_state(self):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
"depth": join_event["depth"] + 1,
"depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
Expand All @@ -149,7 +153,7 @@ def test_rejected_message_event_state(self):

self.assertEqual(sg, sg2)

def test_rejected_state_event_state(self):
def test_rejected_state_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected state events.
Expand Down Expand Up @@ -180,7 +184,7 @@ def test_rejected_state_event_state(self):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
"depth": join_event["depth"] + 1,
"depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
Expand All @@ -203,7 +207,7 @@ def test_rejected_state_event_state(self):

self.assertEqual(sg, sg2)

def test_backfill_with_many_backward_extremities(self):
def test_backfill_with_many_backward_extremities(self) -> None:
"""
Check that we can backfill with many backward extremities.
The goal is to make sure that when we only use a portion
Expand Down Expand Up @@ -262,7 +266,7 @@ def test_backfill_with_many_backward_extremities(self):
)
self.get_success(d)

def test_backfill_floating_outlier_membership_auth(self):
def test_backfill_floating_outlier_membership_auth(self) -> None:
"""
As the local homeserver, check that we can properly process a federated
event from the OTHER_SERVER with auth_events that include a floating
Expand Down Expand Up @@ -377,7 +381,7 @@ async def get_event_auth(
for ae in auth_events
]

self.handler.federation_client.get_event_auth = get_event_auth
self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]

with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
Expand All @@ -397,7 +401,7 @@ async def get_event_auth(
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
def test_invite_by_user_ratelimit(self):
def test_invite_by_user_ratelimit(self) -> None:
"""Tests that invites from federation to a particular user are
actually rate-limited.
"""
Expand Down Expand Up @@ -446,7 +450,9 @@ def create_invite():
exc=LimitExceededError,
)

def _build_and_send_join_event(self, other_server, other_user, room_id):
def _build_and_send_join_event(
self, other_server: str, other_user: str, room_id: str
) -> EventBase:
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
)
Expand All @@ -469,7 +475,7 @@ def _build_and_send_join_event(self, other_server, other_user, room_id):


class EventFromPduTestCase(TestCase):
def test_valid_json(self):
def test_valid_json(self) -> None:
"""Valid JSON should be turned into an event."""
ev = event_from_pdu_json(
{
Expand All @@ -487,7 +493,7 @@ def test_valid_json(self):

self.assertIsInstance(ev, EventBase)

def test_invalid_numbers(self):
def test_invalid_numbers(self) -> None:
"""Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [
-(2 ** 53),
Expand All @@ -512,7 +518,7 @@ def test_invalid_numbers(self):
RoomVersions.V6,
)

def test_invalid_nested(self):
def test_invalid_nested(self) -> None:
"""List and dictionaries are recursively searched."""
with self.assertRaises(SynapseError):
event_from_pdu_json(
Expand Down
13 changes: 9 additions & 4 deletions tests/handlers/test_presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,11 @@ def test_persisting_presence_updates(self):

# Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
presence_states = [(ps.user_id, ps.state) for ps in presence_states]
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]

# Compare what we put into the storage with what we got out.
# They should be identical.
self.assertEqual(presence_states, db_presence_states)
self.assertEqual(presence_states_compare, db_presence_states)


class PresenceTimeoutTestCase(unittest.TestCase):
Expand All @@ -357,6 +357,7 @@ def test_idle_timer(self):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)

Expand All @@ -380,6 +381,7 @@ def test_busy_no_idle(self):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)

Expand All @@ -399,6 +401,7 @@ def test_sync_timeout(self):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)

Expand All @@ -420,6 +423,7 @@ def test_sync_online(self):
)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)

Expand Down Expand Up @@ -477,6 +481,7 @@ def test_federation_timeout(self):
)

self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)

Expand Down Expand Up @@ -653,13 +658,13 @@ def test_set_presence_with_status_msg_none(self):
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)

def _set_presencestate_with_status_msg(
self, user_id: str, state: PresenceState, status_msg: Optional[str]
self, user_id: str, state: str, status_msg: Optional[str]
):
"""Set a PresenceState and status_msg and check the result.
Args:
user_id: User for that the status is to be set.
PresenceState: The new PresenceState.
state: The new PresenceState.
status_msg: Status message that is to be set.
"""
self.get_success(
Expand Down
Loading

0 comments on commit 9e06e22

Please sign in to comment.