Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make auth & transactions more testable (#3499)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkowl authored Jul 13, 2018
1 parent 2aba1f5 commit 33b60c0
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 98 deletions.
Empty file added changelog.d/3499.misc
Empty file.
124 changes: 62 additions & 62 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def get_user_by_req(self, request, allow_guest=False, rights="access"):
synapse.types.create_requester(user_id, app_service=app_service)
)

access_token = get_access_token_from_request(
access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)

Expand Down Expand Up @@ -239,7 +239,7 @@ def get_user_by_req(self, request, allow_guest=False, rights="access"):
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
get_access_token_from_request(
self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
)
Expand Down Expand Up @@ -513,7 +513,7 @@ def _look_up_user_by_access_token(self, token):

def get_appservice_by_req(self, request):
try:
token = get_access_token_from_request(
token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = self.store.get_app_service_by_token(token)
Expand Down Expand Up @@ -673,67 +673,67 @@ def check_can_change_room_list(self, room_id, user):
" edit its room list entry"
)

@staticmethod
def has_access_token(request):
"""Checks if the request has an access_token.
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)

Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)


def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
@staticmethod
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""

auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)

return query_params[0]
return query_params[0]
45 changes: 22 additions & 23 deletions synapse/rest/client/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,44 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging

from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background

logger = logging.getLogger(__name__)


def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
return request.path + "/" + token


CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins


class HttpTransactionCache(object):

def __init__(self, clock):
self.clock = clock
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)

def _get_transaction_key(self, request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token

def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
Expand All @@ -64,7 +63,7 @@ def fetch_or_execute_request(self, request, fn, *args, **kwargs):
fetch_or_execute
"""
return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
self._get_transaction_key(request), fn, *args, **kwargs
)

def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def __init__(self, hs):
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
self.txns = HttpTransactionCache(hs)
3 changes: 1 addition & 2 deletions synapse/rest/client/v1/logout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from twisted.internet import defer

from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError

from .base import ClientV1RestServlet, client_path_patterns
Expand Down Expand Up @@ -51,7 +50,7 @@ def on_POST(self, request):
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = get_access_token_from_request(request)
access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
Expand Down
6 changes: 3 additions & 3 deletions synapse/rest/client/v1/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from twisted.internet import defer

import synapse.util.stringutils as stringutils
from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request
Expand Down Expand Up @@ -67,6 +66,7 @@ def __init__(self, hs):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()

Expand Down Expand Up @@ -310,7 +310,7 @@ def _do_password(self, request, register_json, session):

@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
as_token = get_access_token_from_request(request)
as_token = self.auth.get_access_token_from_request(request)

if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
Expand Down Expand Up @@ -400,7 +400,7 @@ def __init__(self, hs):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)

access_token = get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token(
access_token
)
Expand Down
3 changes: 1 addition & 2 deletions synapse/rest/client/v2_alpha/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from twisted.internet import defer

from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
Expand Down Expand Up @@ -130,7 +129,7 @@ def on_POST(self, request):
#
# In the second case, we require a password to confirm their identity.

if has_access_token(request):
if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
Expand Down
5 changes: 2 additions & 3 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import synapse
import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import (
Expand Down Expand Up @@ -224,7 +223,7 @@ def on_POST(self, request):
desired_username = body['username']

appservice = None
if has_access_token(request):
if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)

# fork off as soon as possible for ASes and shared secret auth which
Expand All @@ -242,7 +241,7 @@ def on_POST(self, request):
# because the IRC bridges rely on being able to register stupid
# IDs.

access_token = get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)

if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/v2_alpha/sendtodevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, hs):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()

def on_PUT(self, request, message_type, txn_id):
Expand Down
5 changes: 4 additions & 1 deletion tests/rest/client/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):

def setUp(self):
self.clock = MockClock()
self.cache = HttpTransactionCache(self.clock)
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)

self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"
Expand Down

0 comments on commit 33b60c0

Please sign in to comment.