Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to tests/rest/client #12066

Merged
merged 4 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12066.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `tests/rest/client`.
70 changes: 38 additions & 32 deletions tests/rest/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from twisted.internet.defer import succeed
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, UserID
from synapse.util import Clock

from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
Expand All @@ -33,11 +37,11 @@


class DummyRecaptchaChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
def __init__(self, hs: HomeServer) -> None:
super().__init__(hs)
self.recaptcha_attempts = []
self.recaptcha_attempts: List[Tuple[dict, str]] = []

def check_auth(self, authdict, clientip):
def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)

Expand All @@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
]
hijack_auth = False

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:

config = self.default_config()

Expand All @@ -61,7 +65,7 @@ def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(config=config)
return hs

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
Expand Down Expand Up @@ -101,7 +105,7 @@ def recaptcha(
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")

def test_fallback_captcha(self):
def test_fallback_captcha(self) -> None:
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
channel = self.register(
Expand Down Expand Up @@ -132,7 +136,7 @@ def test_fallback_captcha(self):
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")

def test_complete_operation_unknown_session(self):
def test_complete_operation_unknown_session(self) -> None:
"""
Attempting to mark an invalid session as complete should error.
"""
Expand Down Expand Up @@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]

def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()

# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
Expand All @@ -182,12 +186,12 @@ def default_config(self):

return config

def create_resource_dict(self):
def create_resource_dict(self) -> Dict[str, Resource]:
resource_dict = super().create_resource_dict()
resource_dict.update(build_synapse_client_resource_tree(self.hs))
return resource_dict

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.device_id = "dev1"
Expand Down Expand Up @@ -229,7 +233,7 @@ def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel:

return channel

def test_ui_auth(self):
def test_ui_auth(self) -> None:
"""
Test user interactive authentication outside of registration.
"""
Expand Down Expand Up @@ -259,7 +263,7 @@ def test_ui_auth(self):
},
)

def test_grandfathered_identifier(self):
def test_grandfathered_identifier(self) -> None:
"""Check behaviour without "identifier" dict

Synapse used to require clients to submit a "user" field for m.login.password
Expand All @@ -286,7 +290,7 @@ def test_grandfathered_identifier(self):
},
)

def test_can_change_body(self):
def test_can_change_body(self) -> None:
"""
The client dict can be modified during the user interactive authentication session.

Expand Down Expand Up @@ -325,7 +329,7 @@ def test_can_change_body(self):
},
)

def test_cannot_change_uri(self):
def test_cannot_change_uri(self) -> None:
"""
The initial requested URI cannot be modified during the user interactive authentication session.
"""
Expand Down Expand Up @@ -362,7 +366,7 @@ def test_cannot_change_uri(self):
)

@unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
def test_can_reuse_session(self):
def test_can_reuse_session(self) -> None:
"""
The session can be reused if configured.

Expand Down Expand Up @@ -409,7 +413,7 @@ def test_can_reuse_session(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_via_sso(self):
def test_ui_auth_via_sso(self) -> None:
"""Test a successful UI Auth flow via SSO

This includes:
Expand Down Expand Up @@ -452,7 +456,7 @@ def test_ui_auth_via_sso(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_does_not_offer_password_for_sso_user(self):
def test_does_not_offer_password_for_sso_user(self) -> None:
login_resp = self.helper.login_via_oidc("username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
Expand All @@ -464,7 +468,7 @@ def test_does_not_offer_password_for_sso_user(self):
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])

def test_does_not_offer_sso_for_password_user(self):
def test_does_not_offer_sso_for_password_user(self) -> None:
channel = self.delete_device(
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
)
Expand All @@ -474,7 +478,7 @@ def test_does_not_offer_sso_for_password_user(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self):
def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
Expand All @@ -491,7 +495,7 @@ def test_offers_both_flows_for_upgraded_user(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self):
def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
# log the user in
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
Expand Down Expand Up @@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
]
hijack_auth = False

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)

Expand All @@ -548,7 +552,7 @@ def use_refresh_token(self, refresh_token: str) -> FakeChannel:
{"refresh_token": refresh_token},
)

def is_access_token_valid(self, access_token) -> bool:
def is_access_token_valid(self, access_token: str) -> bool:
"""
Checks whether an access token is valid, returning whether it is or not.
"""
Expand All @@ -561,7 +565,7 @@ def is_access_token_valid(self, access_token) -> bool:

return code == HTTPStatus.OK

def test_login_issue_refresh_token(self):
def test_login_issue_refresh_token(self) -> None:
"""
A login response should include a refresh_token only if asked.
"""
Expand Down Expand Up @@ -591,7 +595,7 @@ def test_login_issue_refresh_token(self):
self.assertIn("refresh_token", login_with_refresh.json_body)
self.assertIn("expires_in_ms", login_with_refresh.json_body)

def test_register_issue_refresh_token(self):
def test_register_issue_refresh_token(self) -> None:
"""
A register response should include a refresh_token only if asked.
"""
Expand Down Expand Up @@ -627,7 +631,7 @@ def test_register_issue_refresh_token(self):
self.assertIn("refresh_token", register_with_refresh.json_body)
self.assertIn("expires_in_ms", register_with_refresh.json_body)

def test_token_refresh(self):
def test_token_refresh(self) -> None:
"""
A refresh token can be used to issue a new access token.
"""
Expand Down Expand Up @@ -665,7 +669,7 @@ def test_token_refresh(self):
)

@override_config({"refreshable_access_token_lifetime": "1m"})
def test_refreshable_access_token_expiration(self):
def test_refreshable_access_token_expiration(self) -> None:
"""
The access token should have some time as specified in the config.
"""
Expand Down Expand Up @@ -722,7 +726,9 @@ def test_refreshable_access_token_expiration(self):
"nonrefreshable_access_token_lifetime": "10m",
}
)
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(
self,
) -> None:
"""
Tests that the expiry times for refreshable and non-refreshable access
tokens can be different.
Expand Down Expand Up @@ -782,7 +788,7 @@ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self)
@override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
)
def test_refresh_token_expiry(self):
def test_refresh_token_expiry(self) -> None:
"""
The refresh token can be configured to have a limited lifetime.
When that lifetime has ended, the refresh token can no longer be used to
Expand Down Expand Up @@ -834,7 +840,7 @@ def test_refresh_token_expiry(self):
"session_lifetime": "3m",
}
)
def test_ultimate_session_expiry(self):
def test_ultimate_session_expiry(self) -> None:
"""
The session can be configured to have an ultimate, limited lifetime.
"""
Expand Down Expand Up @@ -882,7 +888,7 @@ def test_ultimate_session_expiry(self):
refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
)

def test_refresh_token_invalidation(self):
def test_refresh_token_invalidation(self) -> None:
"""Refresh tokens are invalidated after first use of the next token.

A refresh token is considered invalid if:
Expand Down Expand Up @@ -987,7 +993,7 @@ def test_refresh_token_invalidation(self):
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
)

def test_many_token_refresh(self):
def test_many_token_refresh(self) -> None:
"""
If a refresh is performed many times during a session, there shouldn't be
extra 'cruft' built up over time.
Expand Down
Loading