From f3e6a5754e89cda30fa88ef8b9dfa31e1697a688 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Mon, 16 Nov 2020 10:10:38 +0000 Subject: [PATCH] Allow error response to return a 200 status code (Fixes #114) --- flask_httpauth.py | 5 ++-- tests/test_error_responses.py | 47 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 tests/test_error_responses.py diff --git a/flask_httpauth.py b/flask_httpauth.py index 9183feb..4b8f387 100644 --- a/flask_httpauth.py +++ b/flask_httpauth.py @@ -12,7 +12,7 @@ from functools import wraps from hashlib import md5 from random import Random, SystemRandom -from flask import request, make_response, session, g +from flask import request, make_response, session, g, Response from werkzeug.datastructures import Authorization from werkzeug.security import safe_str_cmp @@ -49,8 +49,9 @@ def error_handler(self, f): @wraps(f) def decorated(*args, **kwargs): res = f(*args, **kwargs) + check_status_code = not isinstance(res, (tuple, Response)) res = make_response(res) - if res.status_code == 200: + if check_status_code and res.status_code == 200: # if user didn't set status code, use 401 res.status_code = 401 if 'WWW-Authenticate' not in res.headers.keys(): diff --git a/tests/test_error_responses.py b/tests/test_error_responses.py new file mode 100644 index 0000000..ad6e64c --- /dev/null +++ b/tests/test_error_responses.py @@ -0,0 +1,47 @@ +import unittest +import base64 +from flask import Flask, Response +from flask_httpauth import HTTPBasicAuth + + +class HTTPAuthTestCase(unittest.TestCase): + responses = [ + ['error', 401], + [('error', 403), 403], + [('error', 200), 200], + [Response('error'), 200], + [Response('error', 403), 403], + ] + + def setUp(self): + app = Flask(__name__) + app.config['SECRET_KEY'] = 'my secret' + + basic_verify_auth = HTTPBasicAuth() + + @basic_verify_auth.verify_password + def basic_verify_auth_verify_password(username, password): + return False + + @basic_verify_auth.error_handler + def error_handler(): + self.assertIsNone(basic_verify_auth.current_user()) + return self.error_response + + @app.route('/') + @basic_verify_auth.login_required + def index(): + return 'index' + + self.app = app + self.basic_verify_auth = basic_verify_auth + self.client = app.test_client() + + def test_default_status_code(self): + creds = base64.b64encode(b'foo:bar').decode('utf-8') + + for r in self.responses: + self.error_response = r[0] + response = self.client.get( + '/', headers={'Authorization': 'Basic ' + creds}) + self.assertEqual(response.status_code, r[1])