Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Decouple synapse.api.auth_blocking.AuthBlocking from synapse.api.auth.Auth. #13021

Merged
merged 5 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13021.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Decouple `synapse.api.auth_blocking.AuthBlocking` from `synapse.api.auth.Auth`.
14 changes: 0 additions & 14 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from twisted.web.server import Request

from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import (
AuthError,
Expand Down Expand Up @@ -67,8 +66,6 @@ def __init__(self, hs: "HomeServer"):
10000, "token_cache"
)

self._auth_blocking = AuthBlocking(self.hs)

self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
Expand Down Expand Up @@ -711,14 +708,3 @@ async def check_user_in_room_or_world_readable(
"User %s not in room %s, and room previews are disabled"
% (user_id, room_id),
)

async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
) -> None:
await self._auth_blocking.check_auth_blocking(
user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
)
5 changes: 3 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class AuthHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self.clock = hs.get_clock()
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
Expand Down Expand Up @@ -985,7 +986,7 @@ async def create_access_token_for_user_id(
not is_appservice_ghost
or self.hs.config.appservice.track_appservice_user_ips
):
await self.auth.check_auth_blocking(user_id)
await self.auth_blocking.check_auth_blocking(user_id)

access_token = self.generate_access_token(target_user_id_obj)
await self.store.add_access_token_to_user(
Expand Down Expand Up @@ -1439,7 +1440,7 @@ async def validate_short_term_login_token(
except Exception:
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)

await self.auth.check_auth_blocking(res.user_id)
await self.auth_blocking.check_auth_blocking(res.user_id)
return res

async def delete_access_token(self, access_token: str) -> None:
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ async def _expire_event(self, event_id: str) -> None:
class EventCreationHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
Expand Down Expand Up @@ -605,7 +605,7 @@ async def create_event(
Returns:
Tuple of created event, Context
"""
await self.auth.check_auth_blocking(requester=requester)
await self.auth_blocking.check_auth_blocking(requester=requester)

if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version_id = event_dict["content"]["room_version"]
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
Expand Down Expand Up @@ -276,7 +277,7 @@ async def register_user(

# do not check_auth_blocking if the call is coming through the Admin API
if not by_admin:
await self.auth.check_auth_blocking(threepid=threepid)
await self.auth_blocking.check_auth_blocking(threepid=threepid)

if localpart is not None:
await self.check_username(localpart, guest_access_token=guest_access_token)
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self.clock = hs.get_clock()
self.hs = hs
self.spam_checker = hs.get_spam_checker()
Expand Down Expand Up @@ -707,7 +708,7 @@ async def create_room(
"""
user_id = requester.user.to_string()

await self.auth.check_auth_blocking(requester=requester)
await self.auth_blocking.check_auth_blocking(requester=requester)

if (
self._server_notices_mxid is not None
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def __init__(self, hs: "HomeServer"):
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state

Expand Down Expand Up @@ -280,7 +280,7 @@ async def wait_for_sync_for_user(
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(requester=requester)
await self.auth_blocking.check_auth_blocking(requester=requester)

res = await self.response_cache.wrap(
sync_config.request_key,
Expand Down
5 changes: 5 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from twisted.web.resource import Resource

from synapse.api.auth import Auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
from synapse.appservice.api import ApplicationServiceApi
Expand Down Expand Up @@ -379,6 +380,10 @@ def get_notifier(self) -> Notifier:
def get_auth(self) -> Auth:
return Auth(self)

@cache_in_self
def get_auth_blocking(self) -> AuthBlocking:
return AuthBlocking(self)

@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
Expand Down
4 changes: 2 additions & 2 deletions synapse/server_notices/resource_limits_server_notices.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, hs: "HomeServer"):
self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._auth_blocking = hs.get_auth_blocking()
self._config = hs.config
self._resouce_limited = False
self._account_data_handler = hs.get_account_data_handler()
Expand Down Expand Up @@ -91,7 +91,7 @@ async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
# Normally should always pass in user_id to check_auth_blocking
# if you have it, but in this case are checking what would happen
# to other users if they were to arrive.
await self._auth.check_auth_blocking()
await self._auth_blocking.check_auth_blocking()
except ResourceLimitError as e:
limit_msg = e.msg
limit_type = e.limit_type
Expand Down
41 changes: 27 additions & 14 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.auth import Auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import UserTypes
from synapse.api.errors import (
AuthError,
Expand Down Expand Up @@ -46,6 +47,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
hs.datastores.main = self.store
hs.get_auth_handler().store = self.store
self.auth = Auth(hs)
self.auth_blocking = AuthBlocking(hs)

# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
Expand Down Expand Up @@ -362,36 +364,41 @@ def test_blocking_mau(self):
small_number_of_users = 1

# Ensure no error thrown
self.get_success(self.auth.check_auth_blocking())
self.get_success(self.auth_blocking.check_auth_blocking())

self.auth_blocking._limit_usage_by_mau = True

self.store.get_monthly_active_count = simple_async_mock(lots_of_users)

e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError
)
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)

# Ensure does not throw an error
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
self.get_success(self.auth.check_auth_blocking())
self.get_success(self.auth_blocking.check_auth_blocking())

def test_blocking_mau__depending_on_user_type(self):
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True

self.store.get_monthly_active_count = simple_async_mock(100)
# Support users allowed
self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
self.get_success(
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
)
self.store.get_monthly_active_count = simple_async_mock(100)
# Bots not allowed
self.get_failure(
self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
ResourceLimitError,
)
self.store.get_monthly_active_count = simple_async_mock(100)
# Real users not allowed
self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)

def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
self.auth_blocking._max_mau_value = 50
Expand Down Expand Up @@ -419,7 +426,7 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
app_service=appservice,
authenticated_entity="@appservice:server",
)
self.get_success(self.auth.check_auth_blocking(requester=requester))
self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))

def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
self.auth_blocking._max_mau_value = 50
Expand Down Expand Up @@ -448,7 +455,8 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
authenticated_entity="@appservice:server",
)
self.get_failure(
self.auth.check_auth_blocking(requester=requester), ResourceLimitError
self.auth_blocking.check_auth_blocking(requester=requester),
ResourceLimitError,
)

def test_reserved_threepid(self):
Expand All @@ -459,18 +467,21 @@ def test_reserved_threepid(self):
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]

self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)

self.get_failure(
self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
ResourceLimitError,
)

self.get_success(self.auth.check_auth_blocking(threepid=threepid))
self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))

def test_hs_disabled(self):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError
)
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
Expand All @@ -485,7 +496,9 @@ def test_hs_disabled_no_server_notices_user(self):

self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError
)
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
Expand All @@ -495,4 +508,4 @@ def test_server_notices_mxid_special_cased(self):
user = "@user:server"
self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
self.get_success(self.auth.check_auth_blocking(user))
self.get_success(self.auth_blocking.check_auth_blocking(user))
2 changes: 1 addition & 1 deletion tests/handlers/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# MAU tests
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = hs.get_auth()._auth_blocking
self.auth_blocking = hs.get_auth_blocking()
self.auth_blocking._max_mau_value = 50

self.small_number_of_users = 1
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ async def get_or_create_user(
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
await self.hs.get_auth().check_auth_blocking()
await self.hs.get_auth_blocking().check_auth_blocking()
need_register = True

try:
Expand Down