Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints for tests/unittest.py. (#12347)
Browse files Browse the repository at this point in the history
In particular, add type hints for get_success and friends, which are then helpful in a bunch of places.
  • Loading branch information
richvdh authored Apr 1, 2022
1 parent 33ebee4 commit f0b0318
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 48 deletions.
1 change: 1 addition & 0 deletions changelog.d/12347.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations for `tests/unittest.py`.
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ exclude = (?x)
|tests/test_server.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/unittest.py
|tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py
|tests/util/caches/test_descriptors.py
Expand Down
6 changes: 4 additions & 2 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,10 @@ def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
res = e.value.code
self.assertEqual(res, 400)

res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
query_res = self.get_success(
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(query_res, {local_user: {}})

def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""
Expand Down
5 changes: 3 additions & 2 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ def test_backfill_floating_outlier_membership_auth(self) -> None:
member_event.signatures = member_event_dict["signatures"]

# Add the new member_event to the StateMap
prev_state_map[
updated_state_map = dict(prev_state_map)
updated_state_map[
(member_event.type, member_event.state_key)
] = member_event.event_id
auth_events.append(member_event)
Expand All @@ -399,7 +400,7 @@ def test_backfill_floating_outlier_membership_auth(self) -> None:
prev_event_ids=message_event_dict["prev_events"],
auth_event_ids=self._event_auth_handler.compute_auth_events(
builder,
prev_state_map,
updated_state_map,
for_verification=False,
),
depth=message_event_dict["depth"],
Expand Down
7 changes: 4 additions & 3 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,11 @@ def test_redirect_request(self) -> None:
req = Mock(spec=["cookies"])
req.cookies = []

url = self.get_success(
self.provider.handle_redirect_request(req, b"http://client/redirect")
url = urlparse(
self.get_success(
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
)
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)

self.assertEqual(url.scheme, auth_endpoint.scheme)
Expand Down
2 changes: 2 additions & 0 deletions tests/handlers/test_user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def test_handle_local_profile_change_with_support_user(self) -> None:
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)

def test_handle_local_profile_change_with_deactivated_user(self) -> None:
Expand All @@ -369,6 +370,7 @@ def test_handle_local_profile_change_with_deactivated_user(self) -> None:

# profile is in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)

# deactivate user
Expand Down
8 changes: 8 additions & 0 deletions tests/rest/admin/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ def test_quarantine_media(self) -> None:
"""

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])

# quarantining
Expand All @@ -715,6 +716,7 @@ def test_quarantine_media(self) -> None:
self.assertFalse(channel.json_body)

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["quarantined_by"])

# remove from quarantine
Expand All @@ -728,6 +730,7 @@ def test_quarantine_media(self) -> None:
self.assertFalse(channel.json_body)

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])

def test_quarantine_protected_media(self) -> None:
Expand All @@ -740,6 +743,7 @@ def test_quarantine_protected_media(self) -> None:

# verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])

# quarantining
Expand All @@ -754,6 +758,7 @@ def test_quarantine_protected_media(self) -> None:

# verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])


Expand Down Expand Up @@ -830,6 +835,7 @@ def test_protect_media(self) -> None:
"""

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])

# protect
Expand All @@ -843,6 +849,7 @@ def test_protect_media(self) -> None:
self.assertFalse(channel.json_body)

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])

# unprotect
Expand All @@ -856,6 +863,7 @@ def test_protect_media(self) -> None:
self.assertFalse(channel.json_body)

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])


Expand Down
15 changes: 9 additions & 6 deletions tests/rest/admin/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,10 +1590,9 @@ def test_create_user_email_notif_for_new_users(self) -> None:
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])

pushers = self.get_success(
self.store.get_pushers_by({"user_name": "@bob:test"})
pushers = list(
self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual("@bob:test", pushers[0].user_name)

Expand Down Expand Up @@ -1632,10 +1631,9 @@ def test_create_user_email_no_notif_for_new_users(self) -> None:
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])

pushers = self.get_success(
self.store.get_pushers_by({"user_name": "@bob:test"})
pushers = list(
self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
)
pushers = list(pushers)
self.assertEqual(len(pushers), 0)

def test_set_password(self) -> None:
Expand Down Expand Up @@ -2144,6 +2142,7 @@ def test_change_name_deactivate_user_user_directory(self) -> None:

# is in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
assert profile is not None
self.assertTrue(profile["display_name"] == "User")

# Deactivate user
Expand Down Expand Up @@ -2711,6 +2710,7 @@ def test_get_pushers(self) -> None:
user_tuple = self.get_success(
self.store.get_user_by_access_token(other_user_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id

self.get_success(
Expand Down Expand Up @@ -3676,6 +3676,7 @@ def test_success(self) -> None:
# The user starts off as not shadow-banned.
other_user_token = self.login("user", "pass")
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertFalse(result.shadow_banned)

channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
Expand All @@ -3684,6 +3685,7 @@ def test_success(self) -> None:

# Ensure the user is shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertTrue(result.shadow_banned)

# Un-shadow-ban the user.
Expand All @@ -3695,6 +3697,7 @@ def test_success(self) -> None:

# Ensure the user is no longer shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertFalse(result.shadow_banned)


Expand Down
6 changes: 4 additions & 2 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
AnyStr,
Callable,
Dict,
Iterable,
Expand Down Expand Up @@ -86,6 +85,9 @@

logger = logging.getLogger(__name__)

# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]


class TimedOutException(Exception):
"""
Expand Down Expand Up @@ -260,7 +262,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/storage/databases/main/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_simple_lock(self):
"""
# First to acquire this lock, so it should complete
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None

# Enter the context manager
self.get_success(lock.__aenter__())
Expand All @@ -45,15 +45,15 @@ def test_simple_lock(self):

# We can now acquire the lock again.
lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock3)
assert lock3 is not None
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))

def test_maintain_lock(self):
"""Test that we don't time out locks while they're still active"""

lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None

self.get_success(lock.__aenter__())

Expand All @@ -69,7 +69,7 @@ def test_timeout_lock(self):
"""Test that we time out locks if they're not updated for ages"""

lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None

self.get_success(lock.__aenter__())

Expand Down
1 change: 1 addition & 0 deletions tests/storage/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def test_get_oldest_unsent_txn(self) -> None:
self.get_success(self._insert_txn(service.id, 12, other_events))

txn = self.get_success(self.store.get_oldest_unsent_txn(service))
assert txn is not None
self.assertEqual(service, txn.service)
self.assertEqual(10, txn.id)
self.assertEqual(events, txn.events)
Expand Down
Loading

0 comments on commit f0b0318

Please sign in to comment.