diff --git a/flask_security/views.py b/flask_security/views.py index 24f5fbea..8c7ff4b7 100644 --- a/flask_security/views.py +++ b/flask_security/views.py @@ -1160,10 +1160,21 @@ def recover_username(): if user: send_username_recovery_email(user) - do_flash(*get_message("USERNAME_RECOVERY_REQUEST", email=form.email.data)) + if not _security._want_json(request): + do_flash(*get_message("USERNAME_RECOVERY_REQUEST", email=form.email.data)) + elif request.method == "POST" and cv("RETURN_GENERIC_RESPONSES"): + rinfo = dict(email=dict()) + form_errors_munge(form, rinfo) + if not form.errors: + if not _security._want_json(request): + do_flash( + *get_message("USERNAME_RECOVERY_REQUEST", email=form.email.data) + ) - if _security._want_json(request): - return base_render_json(form, include_auth_token=True) + if _security._want_json(request): + return base_render_json(form, include_user=False) + + if form.validate_on_submit(): return redirect(url_for_security("login")) return _security.render_template( diff --git a/tests/conftest.py b/tests/conftest.py index 85c8852d..4cb5780d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -127,6 +127,8 @@ def app(request: pytest.FixtureRequest) -> SecurityFixture: # Make this hex_md5 for token tests app.config["SECURITY_HASHING_SCHEMES"] = ["hex_md5"] app.config["SECURITY_DEPRECATED_HASHING_SCHEMES"] = [] + # Enable username recovery for tests + app.config["SECURITY_USERNAME_RECOVERY"] = True for opt in [ "changeable", diff --git a/tests/test_recoverable.py b/tests/test_recoverable.py index 839cb0f4..c6bf729a 100644 --- a/tests/test_recoverable.py +++ b/tests/test_recoverable.py @@ -836,6 +836,14 @@ def on_email_sent(app, **kwargs): email = app.mail.outbox[1] assert "Your username is: joe" in email.body + # Test JSON responses + response = clients.post( + "/recover-username", + json=dict(email="joe@lp.com"), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 200 + def test_username_recovery_invalid_email(app, clients): response = clients.post( @@ -844,3 +852,68 @@ def test_username_recovery_invalid_email(app, clients): assert not app.mail.outbox assert response.status_code == 200 + + # Test JSON responses + response = clients.post( + "/recover-username", + json=dict(email="bogus@lp.com"), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 400 + + +@pytest.mark.settings(return_generic_responses=True) +def test_username_recovery_generic_responses(app, clients, get_message): + recorded_recovery_sent = [] + + @username_recovery_email_sent.connect_via(app) + def on_email_sent(app, **kwargs): + recorded_recovery_sent.append(kwargs["user"]) + + # Test with valid email + with capture_flashes() as flashes: + response = clients.post( + "/recover-username", + data=dict(email="joe@lp.com"), + follow_redirects=True, + ) + assert len(flashes) == 1 + assert get_message("USERNAME_RECOVERY_REQUEST") == flashes[0]["message"].encode( + "utf-8" + ) + assert len(recorded_recovery_sent) == 1 + assert len(app.mail.outbox) == 1 + assert response.status_code == 200 + + # Test with non-existant email (should still return 200) + with capture_flashes() as flashes: + response = clients.post( + "/recover-username", + data=dict(email="bogus@lp.com"), + follow_redirects=True, + ) + assert len(flashes) == 1 + assert get_message("USERNAME_RECOVERY_REQUEST") == flashes[0]["message"].encode( + "utf-8" + ) + # Validate no email was sent (there should only be one from the previous test) + assert len(recorded_recovery_sent) == 1 + assert len(app.mail.outbox) == 1 + assert response.status_code == 200 + + # Test JSON responses - valid email + response = clients.post( + "/recover-username", + json=dict(email="joe@lp.com"), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 200 + + # Test JSON responses - invalid email + response = clients.post( + "/recover-username", + json=dict(email="bogus@lp.com"), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 200 + assert not any(e in response.json["response"].keys() for e in ["error", "errors"])