diff --git a/docs/index.rst b/docs/index.rst index 5bc5a76..8fa09cd 100755 --- a/docs/index.rst +++ b/docs/index.rst @@ -197,6 +197,7 @@ API Documentation The ``realm`` argument can be used to provide an application defined realm with the ``WWW-Authenticate`` header. + .. method:: get_password(password_callback) This callback function will be called by the framework to obtain the password for a given user. Example:: @@ -331,7 +332,7 @@ API Documentation This class handles HTTP authentication with custom schemes for Flask routes. - .. method:: __init__(scheme='Bearer', realm=None) + .. method:: __init__(scheme='Bearer', realm=None, header=None) Create a token authentication object. @@ -339,6 +340,8 @@ API Documentation The ``realm`` argument can be used to provide an application defined realm with the ``WWW-Authenticate`` header. + The ``header`` optional argument, defaults to ``Authorization``. It can be used to define a custom token header. + .. method:: verify_token(verify_token_callback) This callback function will be called by the framework to verify that the credentials sent by the client with the ``Authorization`` header are valid. The callback function takes one argument, the username and the password and must return ``True`` or ``False``. Example usage:: diff --git a/flask_httpauth.py b/flask_httpauth.py index 10f5576..623b598 100644 --- a/flask_httpauth.py +++ b/flask_httpauth.py @@ -19,9 +19,10 @@ class HTTPAuth(object): - def __init__(self, scheme=None, realm=None): + def __init__(self, scheme=None, realm=None, header=None): self.scheme = scheme self.realm = realm or "Authentication Required" + self.header = header self.get_password_callback = None self.get_user_roles_callback = None self.auth_error_callback = None @@ -61,18 +62,25 @@ def authenticate_header(self): return '{0} realm="{1}"'.format(self.scheme, self.realm) def get_auth(self): - auth = request.authorization - if auth is None and 'Authorization' in request.headers: - # Flask/Werkzeug do not recognize any authentication types - # other than Basic or Digest, so here we parse the header by - # hand - try: - auth_type, token = request.headers['Authorization'].split( - None, 1) - auth = Authorization(auth_type, {'token': token}) - except ValueError: - # The Authorization header is either empty or has no token - pass + auth = None + if self.header is None or self.header == 'Authorization': + auth = request.authorization + if auth is None and 'Authorization' in request.headers: + # Flask/Werkzeug do not recognize any authentication types + # other than Basic or Digest, so here we parse the header by + # hand + try: + auth_type, token = request.headers['Authorization'].split( + None, 1) + auth = Authorization(auth_type, {'token': token}) + except (ValueError, KeyError): + # The Authorization header is either empty or has no token + pass + elif self.header in request.headers: + # using a custom header, so the entire value of the header is + # assumed to be a token + auth = Authorization(self.scheme, + {'token': request.headers[self.header]}) # if the auth type does not match, we act as if there is no auth # this is better than failing directly, as it allows the callback @@ -302,8 +310,8 @@ def authenticate(self, auth, stored_password_or_ha1): class HTTPTokenAuth(HTTPAuth): - def __init__(self, scheme='Bearer', realm=None): - super(HTTPTokenAuth, self).__init__(scheme, realm) + def __init__(self, scheme='Bearer', realm=None, header=None): + super(HTTPTokenAuth, self).__init__(scheme, realm, header) self.verify_token_callback = None @@ -318,7 +326,6 @@ def authenticate(self, auth, stored_password): token = "" if self.verify_token_callback: return self.verify_token_callback(token) - return False class MultiAuth(object): diff --git a/tests/test_token.py b/tests/test_token.py index b9b5b05..767d456 100644 --- a/tests/test_token.py +++ b/tests/test_token.py @@ -1,3 +1,4 @@ +import base64 import unittest from flask import Flask from flask_httpauth import HTTPTokenAuth @@ -9,12 +10,19 @@ def setUp(self): app.config['SECRET_KEY'] = 'my secret' token_auth = HTTPTokenAuth('MyToken') + token_auth2 = HTTPTokenAuth('Token', realm='foo') + token_auth3 = HTTPTokenAuth(header='X-API-Key') @token_auth.verify_token def verify_token(token): if token == 'this-is-the-token!': return 'user' + @token_auth3.verify_token + def verify_token3(token): + if token == 'this-is-the-token!': + return 'user' + @token_auth.error_handler def error_handler(): return 'error', 401, {'WWW-Authenticate': 'MyToken realm="Foo"'} @@ -28,6 +36,16 @@ def index(): def token_auth_route(): return 'token_auth:' + token_auth.current_user() + @app.route('/protected2') + @token_auth2.login_required + def token_auth_route2(): + return 'token_auth2' + + @app.route('/protected3') + @token_auth3.login_required + def token_auth_route3(): + return 'token_auth3:' + token_auth3.current_user() + self.app = app self.token_auth = token_auth self.client = app.test_client() @@ -82,13 +100,6 @@ def test_token_auth_login_invalid_header(self): 'MyToken realm="Foo"') def test_token_auth_login_invalid_no_callback(self): - token_auth2 = HTTPTokenAuth('Token', realm='foo') - - @self.app.route('/protected2') - @token_auth2.login_required - def token_auth_route2(): - return 'token_auth2' - response = self.client.get( '/protected2', headers={'Authorization': 'Token this-is-the-token!'}) @@ -96,3 +107,31 @@ def token_auth_route2(): self.assertTrue('WWW-Authenticate' in response.headers) self.assertEqual(response.headers['WWW-Authenticate'], 'Token realm="foo"') + + def test_token_auth_custom_header_valid_token(self): + response = self.client.get( + '/protected3', headers={'X-API-Key': 'this-is-the-token!'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data.decode('utf-8'), 'token_auth3:user') + + def test_token_auth_custom_header_invalid_token(self): + response = self.client.get( + '/protected3', headers={'X-API-Key': 'invalid-token-should-fail'}) + self.assertEqual(response.status_code, 401) + self.assertTrue('WWW-Authenticate' in response.headers) + + def test_token_auth_custom_header_invalid_header(self): + response = self.client.get( + '/protected3', headers={'API-Key': 'this-is-the-token!'}) + self.assertEqual(response.status_code, 401) + self.assertTrue('WWW-Authenticate' in response.headers) + self.assertEqual(response.headers['WWW-Authenticate'], + 'Bearer realm="Authentication Required"') + + def test_token_auth_header_precedence(self): + basic_creds = base64.b64encode(b'susan:bye').decode('utf-8') + response = self.client.get( + '/protected3', headers={'Authorization': 'Basic ' + basic_creds, + 'X-API-Key': 'this-is-the-token!'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data.decode('utf-8'), 'token_auth3:user')