Skip to content

Commit

Permalink
fix(headless): Email verification by code & change email
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Dec 1, 2024
1 parent 7a10879 commit 45ea40a
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 35 deletions.
6 changes: 5 additions & 1 deletion ChangeLog.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
65.3.1 (unreleased)
*******************

- ...
Fixes
-----

- Headless: When using email verification by code, you could incorrectly
encounter a 409 when attempting to add a new email address while logged in.


65.3.0 (2024-11-30)
Expand Down
6 changes: 6 additions & 0 deletions allauth/account/internal/flows/email_verification_by_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def from_key(cls, key):
def key_expired(self):
return False

def confirm(self, request) -> Optional[EmailAddress]:
ret = super().confirm(request)
if ret:
clear_state(request)
return ret


def clear_state(request):
request.session.pop(EMAIL_VERIFICATION_CODE_SESSION_KEY, None)
Expand Down
16 changes: 15 additions & 1 deletion allauth/account/internal/flows/manage_email.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional

from django.contrib import messages
from django.http import HttpRequest
Expand Down Expand Up @@ -171,3 +171,17 @@ def assess_unique_email(email) -> Optional[bool]:
# to be unique. In this case, uniqueness takes precedence over
# enumeration prevention.
return False


def list_email_addresses(request, user) -> List[EmailAddress]:
addresses = list(EmailAddress.objects.filter(user=user))
if app_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED:
from allauth.account.internal.flows.email_verification_by_code import (
get_pending_verification,
)

verification, _ = get_pending_verification(request, peek=True)
if verification and verification.email_address.user_id == user.pk:
addresses.append(verification.email_address)

return addresses
12 changes: 7 additions & 5 deletions allauth/account/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,15 @@ def remove(self):


class EmailConfirmationMixin:
def confirm(self, request):
email_address = self.email_address
def confirm(self, request) -> Optional[EmailAddress]:
email_address = self.email_address # type: ignore[attr-defined]
if not email_address.verified:
confirmed = get_adapter().confirm_email(request, email_address)
if confirmed:
return email_address
return None

def send(self, request=None, signup=False):
def send(self, request=None, signup=False) -> None:
get_adapter().send_confirmation_mail(request, self, signup)
signals.email_confirmation_sent.send(
sender=self.__class__,
Expand Down Expand Up @@ -180,11 +181,12 @@ def key_expired(self):

key_expired.boolean = True # type: ignore[attr-defined]

def confirm(self, request):
def confirm(self, request) -> Optional[EmailAddress]:
if not self.key_expired():
return super().confirm(request)
return None

def send(self, request=None, signup=False):
def send(self, request=None, signup=False) -> None:
super().send(request=request, signup=signup)
self.sent = timezone.now()
self.save()
Expand Down
35 changes: 12 additions & 23 deletions allauth/account/tests/test_email_verification_by_code.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import re
from unittest.mock import patch

from django.conf import settings
Expand All @@ -8,27 +7,9 @@

import pytest

from allauth.account.internal.flows import email_verification_by_code
from allauth.account.models import EmailAddress


@pytest.fixture
def get_last_code(client, mailoutbox):
def f():
code = re.search(
"\n[0-9a-z]{6}\n", mailoutbox[0].body, re.I | re.DOTALL | re.MULTILINE
)[0].strip()
assert (
client.session[
email_verification_by_code.EMAIL_VERIFICATION_CODE_SESSION_KEY
]["code"]
== code
)
return code

return f


@pytest.fixture(autouse=True)
def email_verification_settings(settings):
settings.ACCOUNT_EMAIL_VERIFICATION_BY_CODE_ENABLED = True
Expand All @@ -45,7 +26,13 @@ def email_verification_settings(settings):
],
)
def test_signup(
client, db, settings, password_factory, get_last_code, query, expected_url
client,
db,
settings,
password_factory,
get_last_email_verification_code,
query,
expected_url,
):
password = password_factory()
resp = client.post(
Expand All @@ -58,7 +45,7 @@ def test_signup(
},
)
assert get_user_model().objects.filter(username="johndoe").count() == 1
code = get_last_code()
code = get_last_email_verification_code()
assert resp.status_code == 302
assert resp["location"] == reverse("account_email_verification_sent")
resp = client.get(reverse("account_email_verification_sent"))
Expand Down Expand Up @@ -100,7 +87,9 @@ def test_signup_prevent_enumeration(


@pytest.mark.parametrize("change_email", (False, True))
def test_add_or_change_email(auth_client, user, get_last_code, change_email, settings):
def test_add_or_change_email(
auth_client, user, get_last_email_verification_code, change_email, settings
):
settings.ACCOUNT_CHANGE_EMAIL = change_email
email = "additional@email.org"
assert EmailAddress.objects.filter(user=user).count() == 1
Expand All @@ -113,7 +102,7 @@ def test_add_or_change_email(auth_client, user, get_last_code, change_email, set
assert not email_added_signal.send.called
assert not email_changed_signal.send.called
assert EmailAddress.objects.filter(email=email).count() == 0
code = get_last_code()
code = get_last_email_verification_code()
resp = auth_client.get(reverse("account_email_verification_sent"))
assert resp.status_code == 200
with patch("allauth.account.signals.email_added") as email_added_signal:
Expand Down
20 changes: 20 additions & 0 deletions allauth/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import random
import re
import time
import uuid
from contextlib import contextmanager
Expand Down Expand Up @@ -324,3 +325,22 @@ def get(self, path):
return request

return RequestFactory()


@pytest.fixture
def get_last_email_verification_code(client, mailoutbox):
from allauth.account.internal.flows import email_verification_by_code

def f():
code = re.search(
"\n[0-9a-z]{6}\n", mailoutbox[0].body, re.I | re.DOTALL | re.MULTILINE
)[0].strip()
assert (
client.session[
email_verification_by_code.EMAIL_VERIFICATION_CODE_SESSION_KEY
]["code"]
== code
)
return code

return f
63 changes: 63 additions & 0 deletions allauth/headless/account/tests/test_email_verification_by_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from allauth.account import app_settings
from allauth.account.models import EmailAddress
from allauth.headless.constants import Flow


Expand Down Expand Up @@ -106,3 +107,65 @@ def test_email_verification_rate_limits_submitting_codes(
assert resp.status_code == 400
else:
assert resp.status_code == 409


def test_add_email(
auth_client,
user,
email_factory,
headless_reverse,
settings,
get_last_email_verification_code,
):
settings.ACCOUNT_AUTHENTICATION_METHOD = "email"
settings.ACCOUNT_EMAIL_VERIFICATION_BY_CODE_ENABLED = True
settings.ACCOUNT_CHANGE_EMAIL = True
new_email = email_factory()

# Let's add an email...
resp = auth_client.post(
headless_reverse("headless:account:manage_email"),
data={"email": new_email},
content_type="application/json",
)
assert resp.status_code == 200

# It's in the response, albeit unverified.
assert len(resp.json()["data"]) == 2
email_map = {addr["email"]: addr for addr in resp.json()["data"]}
assert not email_map[new_email]["verified"]

# Verify the email with an invalid code.
resp = auth_client.post(
headless_reverse("headless:account:verify_email"),
data={"key": "key"},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.json() == {
"status": 400,
"errors": [
{"message": "Incorrect code.", "code": "incorrect_code", "param": "key"}
],
}

# And with the valid code...
code = get_last_email_verification_code()
resp = auth_client.post(
headless_reverse("headless:account:verify_email"),
data={"key": code},
content_type="application/json",
)
assert resp.status_code == 200
assert resp.json()["data"]["user"]["email"] == new_email

# ACCOUNT_CHANGE_EMAIL = True, so the other one is gone.
assert EmailAddress.objects.filter(user=user).count() == 1

# Re-verification won't work...
resp = auth_client.post(
headless_reverse("headless:account:verify_email"),
data={"key": code},
content_type="application/json",
)
assert resp.status_code == 400
15 changes: 11 additions & 4 deletions allauth/headless/account/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from allauth.account import app_settings as account_settings
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.internal import flows
from allauth.account.internal.flows import password_change, password_reset
from allauth.account.models import EmailAddress
from allauth.account.internal.flows import (
manage_email,
password_change,
password_reset,
)
from allauth.account.stages import EmailVerificationStage, LoginStageController
from allauth.account.utils import send_email_confirmation
from allauth.core import ratelimit
Expand Down Expand Up @@ -127,7 +130,11 @@ class VerifyEmailView(APIView):

def handle(self, request, *args, **kwargs):
self.stage = LoginStageController.enter(request, EmailVerificationStage.key)
if not self.stage and account_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED:
if (
not self.stage
and account_settings.EMAIL_VERIFICATION_BY_CODE_ENABLED
and not request.user.is_authenticated
):
return ConflictResponse(request)
return super().handle(request, *args, **kwargs)

Expand Down Expand Up @@ -236,7 +243,7 @@ def get(self, request, *args, **kwargs):
return self._respond_email_list()

def _respond_email_list(self):
addrs = EmailAddress.objects.filter(user=self.request.user)
addrs = manage_email.list_email_addresses(self.request, self.request.user)
return response.EmailAddressesResponse(self.request, addrs)

def post(self, request, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ build-backend = 'setuptools.build_meta'


[tool.bandit]
exclude_dirs = ["tests"]
exclude_dirs = ["tests", "allauth/conftest.py"]
exclude = ["test_*"]

0 comments on commit 45ea40a

Please sign in to comment.