Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace get_user_model calls with Member #281

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/authentication/basic_auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from django.contrib.auth import authenticate, get_user_model, password_validation
from django.contrib.auth import authenticate, password_validation
from rest_framework.exceptions import ValidationError
from rest_framework.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED

from authentication.providers import LoginProvider, RegistrationProvider, TokenProvider
from backend.exceptions import FormattedException
from backend.signals import login, login_reject
from member.models import Member


class BasicAuthRegistrationProvider(RegistrationProvider):
Expand All @@ -22,7 +23,7 @@ def validate(self, data):
return {key: data[key] for key in self.required_fields}

def register_user(self, username, email, password, **kwargs):
user = get_user_model()(username=username, email=email)
user = Member(username=username, email=email)

try:
password_validation.validate_password(password, user)
Expand Down
1 change: 0 additions & 1 deletion src/authentication/providers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import abc
import re

from django.contrib.auth import get_user_model
from django.core.validators import EmailValidator
from rest_framework.exceptions import ValidationError

Expand Down
11 changes: 6 additions & 5 deletions src/authentication/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from anymail.exceptions import AnymailAPIError
from django.conf import settings
from django.contrib.auth import get_user_model, password_validation
from django.contrib.auth import password_validation
from django.utils import timezone
from rest_framework import serializers
from rest_framework.generics import get_object_or_404
Expand All @@ -19,6 +19,7 @@
from backend.mail import send_email
from backend.signals import register
from config import config
from member.models import Member
from plugins import providers
from team.models import Team

Expand Down Expand Up @@ -62,7 +63,7 @@ def validate(self, _):
def create(self, validated_data):
user = providers.get_provider("registration").register_user(**validated_data, context=self.context)

if not get_user_model().objects.all().exists():
if not Member.objects.all().exists():
user.is_staff = True
user.is_superuser = True

Expand Down Expand Up @@ -141,7 +142,7 @@ def validate(self, data):
uid = data.get("uid")
token = data.get("token")
password = data.get("password")
user = get_object_or_404(get_user_model(), id=uid)
user = get_object_or_404(Member, id=uid)
reset_token = get_object_or_404(PasswordResetToken, token=token, user_id=uid, expires__gt=timezone.now())
password_validation.validate_password(password, reset_token)
data["user"] = user
Expand All @@ -156,7 +157,7 @@ class EmailVerificationSerializer(serializers.Serializer):
def validate(self, data):
uid = int(data.get("uid"))
token = data.get("token")
user = get_object_or_404(get_user_model(), id=uid, email_token=token)
user = get_object_or_404(Member, id=uid, email_token=token)
if user.email_verified:
raise serializers.ValidationError("email_is_already_verified")
data["user"] = user
Expand All @@ -167,7 +168,7 @@ class EmailSerializer(serializers.Serializer):
email = serializers.EmailField()

def validate(self, data):
user = get_object_or_404(get_user_model(), email=data.get("email"))
user = get_object_or_404(Member, email=data.get("email"))
if user.email_verified:
raise serializers.ValidationError("email_is_already_verified")
data["user"] = user
Expand Down
74 changes: 37 additions & 37 deletions src/authentication/tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from unittest import mock

import pyotp
from django.contrib.auth import get_user_model
from django.http import HttpRequest
from django.urls import reverse
from rest_framework.request import Request
Expand Down Expand Up @@ -37,6 +36,7 @@
VerifyTwoFactorView,
)
from config import config
from member.models import Member
from team.models import Team


Expand Down Expand Up @@ -130,7 +130,7 @@ def test_register_admin(self):
"email": "user6@example.org",
}
self.client.post(reverse("register"), data)
self.assertTrue(get_user_model().objects.filter(username=data["username"]).first().is_staff)
self.assertTrue(Member.objects.filter(username=data["username"]).first().is_staff)

def test_register_second(self):
data = {
Expand All @@ -145,7 +145,7 @@ def test_register_second(self):
"email": "user7@example.org",
}
self.client.post(reverse("register"), data)
self.assertFalse(get_user_model().objects.filter(username=data["username"]).first().is_staff)
self.assertFalse(Member.objects.filter(username=data["username"]).first().is_staff)

def test_register_malformed(self):
data = {
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_register_teams_disabled(self):
response = self.client.post(reverse("register"), data)
config.set("enable_teams", True)
self.assertEqual(response.status_code, HTTP_201_CREATED)
self.assertEqual(get_user_model().objects.get(username="user10").team.name, "user10")
self.assertEqual(Member.objects.get(username="user10").team.name, "user10")

def test_register_with_mail_passing_regex(self):
with self.settings(
Expand Down Expand Up @@ -265,14 +265,14 @@ def test_register_duplicate_username_different_casing(self):
class EmailResendTestCase(APITestCase):
def test_email_resend(self):
with self.settings(RATELIMIT_ENABLE=False):
user = get_user_model()(username="test_verify_user", email_verified=False, email="tvu@example.com")
user = Member(username="test_verify_user", email_verified=False, email="tvu@example.com")
user.save()
response = self.client.post(reverse("resend-email"), {"email": "tvu@example.com"})
self.assertEqual(response.status_code, HTTP_200_OK)

def test_already_verified_email_resend(self):
with self.settings(RATELIMIT_ENABLE=False):
user = get_user_model()(username="resend-email", email_verified=True, email="tvu@example.com")
user = Member(username="resend-email", email_verified=True, email="tvu@example.com")
user.save()
response = self.client.post(reverse("resend-email"), {"email": "tvu@example.com"})
self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)
Expand All @@ -285,9 +285,9 @@ def test_non_existing_email_resend(self):

class SudoTestCase(APITestCase):
def test_sudo(self):
user = get_user_model()(username="sudotest", is_staff=True, email="sudotest@example.com", is_superuser=True)
user = Member(username="sudotest", is_staff=True, email="sudotest@example.com", is_superuser=True)
user.save()
user2 = get_user_model()(username="sudotest2", email="sudotest2@example.com")
user2 = Member(username="sudotest2", email="sudotest2@example.com")
user2.save()

self.client.force_authenticate(user)
Expand All @@ -297,7 +297,7 @@ def test_sudo(self):

class DesudoTestCase(APITestCase):
def test_desudo(self):
user2 = get_user_model()(username="sudotest2", email="sudotest2@example.com")
user2 = Member(username="sudotest2", email="sudotest2@example.com")
user2.save()

request = Request(HttpRequest())
Expand All @@ -307,7 +307,7 @@ def test_desudo(self):
self.assertTrue("token" in response.data["d"])

def test_desudo_staff(self):
user2 = get_user_model()(username="sudotest2", email="sudotest2@example.com")
user2 = Member(username="sudotest2", email="sudotest2@example.com")
user2.is_staff = True
user2.save()

Expand All @@ -320,15 +320,15 @@ def test_desudo_staff(self):

class GenerateInvitesTestCase(APITestCase):
def test_response_length(self):
user = get_user_model()(username="resend-email", is_staff=True, email="tvu@example.com", is_superuser=True)
user = Member(username="resend-email", is_staff=True, email="tvu@example.com", is_superuser=True)
user.save()
self.client.force_authenticate(user=user)
team = Team.objects.create(owner=user, name=user.username, password="123123")
response = self.client.post(reverse("generate-invites"), {"amount": 15, "auto_team": team.id, "max_uses": 1})
self.assertEqual(len(response.data["d"]["invite_codes"]), 15)

def test_invites_viewset(self):
user = get_user_model()(username="resend-email", is_staff=True, email="tvu@example.com", is_superuser=True)
user = Member(username="resend-email", is_staff=True, email="tvu@example.com", is_superuser=True)
user.save()
self.client.force_authenticate(user=user)
self.client.post(reverse("generate-invites"), {"amount": 15, "max_uses": 1})
Expand All @@ -343,7 +343,7 @@ def setUp(self):
InviteCode(code="test1", max_uses=10).save()
InviteCode(code="test2", max_uses=1).save()
InviteCode(code="test3", max_uses=1).save()
user = get_user_model()(
user = Member(
username="invtestadmin",
email="invtestadmin@example.org",
email_verified=True,
Expand Down Expand Up @@ -425,12 +425,12 @@ def test_register_invite_required_auto_team(self):
"invite": "test4",
}
self.client.post(reverse("register"), data)
self.assertEqual(get_user_model().objects.get(username="user12").team.id, self.team.id)
self.assertEqual(Member.objects.get(username="user12").team.id, self.team.id)


class LogoutTestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="logout-test", email="logout-test@example.org")
user = Member(username="logout-test", email="logout-test@example.org")
user.set_password("password")
user.email_verified = True
user.save()
Expand All @@ -449,7 +449,7 @@ def test_logout_not_logged_in(self):

class LoginTestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="login-test", email="login-test@example.org")
user = Member(username="login-test", email="login-test@example.org")
user.set_password("password")
user.email_verified = True
user.save()
Expand Down Expand Up @@ -561,7 +561,7 @@ def test_login_2fa_required(self):

class Login2FATestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="login-test", email="login-test@example.org")
user = Member(username="login-test", email="login-test@example.org")
user.set_password("password")
user.email_verified = True
user.save()
Expand Down Expand Up @@ -590,7 +590,7 @@ def test_login_2fa_invalid(self):
self.assertEqual(response.status_code, HTTP_401_UNAUTHORIZED)

def test_login_2fa_without_2fa(self):
user = get_user_model()(username="login-test-no-2fa", email="login-test-no-2fa@example.org")
user = Member(username="login-test-no-2fa", email="login-test-no-2fa@example.org")
user.set_password("password")
user.email_verified = True
user.save()
Expand Down Expand Up @@ -638,13 +638,13 @@ def test_login_2fa_invalid_code(self):

class TokenTestCase(APITestCase):
def test_token_str(self):
user = get_user_model()(username="token-test", email="token-test@example.org")
user = Member(username="token-test", email="token-test@example.org")
user.save()
tok = Token(key="a" * 40, user=user)
self.assertEqual(str(tok), "a" * 40)

def test_token_preserves_key(self):
user = get_user_model()(username="token-test-2", email="token-test-2@example.org")
user = Member(username="token-test-2", email="token-test-2@example.org")
user.save()
token = Token(key="a" * 40, user=user)
token.save()
Expand All @@ -653,7 +653,7 @@ def test_token_preserves_key(self):

class TFATestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="2fa-test", email="2fa-test@example.org")
user = Member(username="2fa-test", email="2fa-test@example.org")
user.set_password("password")
user.email_verified = True
user.save()
Expand Down Expand Up @@ -704,37 +704,37 @@ def test_add_2fa_with_2fa(self):
def test_remove_2fa(self):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("add-2fa"))
totp_device = get_user_model().objects.get(id=self.user.id).totp_device
totp_device = Member.objects.get(id=self.user.id).totp_device
totp_device.verified = True
totp_device.save()
self.client.force_authenticate(user=get_user_model().objects.get(id=self.user.id))
self.client.force_authenticate(user=Member.objects.get(id=self.user.id))
response = self.client.post(reverse("remove-2fa"), data={"otp": pyotp.TOTP(totp_device.totp_secret).now()})
self.assertEqual(response.status_code, HTTP_200_OK)

def test_remove_2fa_fail(self):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("add-2fa"))
totp_device = get_user_model().objects.get(id=self.user.id).totp_device
totp_device = Member.objects.get(id=self.user.id).totp_device
totp_device.verified = True
totp_device.save()
self.client.force_authenticate(user=get_user_model().objects.get(id=self.user.id))
self.client.force_authenticate(user=Member.objects.get(id=self.user.id))
response = self.client.post(reverse("remove-2fa"), data={"otp": "invalid_otp"})
self.assertEqual(response.status_code, HTTP_401_UNAUTHORIZED)

def test_remove_2fa_removes_2fa(self):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("add-2fa"))
totp_device = get_user_model().objects.get(id=self.user.id).totp_device
totp_device = Member.objects.get(id=self.user.id).totp_device
totp_device.verified = True
totp_device.save()
self.client.force_authenticate(user=get_user_model().objects.get(id=self.user.id))
self.client.force_authenticate(user=Member.objects.get(id=self.user.id))
self.client.post(reverse("remove-2fa"), data={"otp": pyotp.TOTP(totp_device.totp_secret).now()})
self.assertFalse(get_user_model().objects.get(id=self.user.id).has_2fa())
self.assertFalse(Member.objects.get(id=self.user.id).has_2fa())

def test_remove_2fa_no_2fa(self):
self.client.force_authenticate(user=self.user)
self.client.post(reverse("add-2fa"))
user = get_user_model().objects.get(id=self.user.id)
user = Member.objects.get(id=self.user.id)
user.totp_device = None
user.save()
response = self.client.post(reverse("remove-2fa"))
Expand All @@ -753,14 +753,14 @@ def test_password_reset_request_valid(self):
with self.settings(
MAIL={"SEND_ADDRESS": "no-reply@ractf.co.uk", "SEND_NAME": "RACTF", "SEND": True, "SEND_MODE": "SES"}
):
get_user_model()(username="test-password-rest", email="user10@example.org", email_verified=True).save()
Member(username="test-password-rest", email="user10@example.org", email_verified=True).save()
response = self.client.post(reverse("request-password-reset"), data={"email": "user10@example.org"})
self.assertEqual(response.status_code, HTTP_200_OK)


class DoPasswordResetTestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="pr-test", email="pr-test@example.org")
user = Member(username="pr-test", email="pr-test@example.org")
user.set_password("password")
user.email_verified = True
user.save()
Expand Down Expand Up @@ -841,7 +841,7 @@ def test_password_reset_cant_login_yet(self, obj):

class VerifyEmailTestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="ev-test", email="ev-test@example.org")
user = Member(username="ev-test", email="ev-test@example.org")
user.set_password("password")
user.save()
self.user = user
Expand Down Expand Up @@ -894,7 +894,7 @@ def test_email_verify_bad_token(self):

class ChangePasswordTestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="cp-test", email="cp-test@example.org")
user = Member(username="cp-test", email="cp-test@example.org")
user.set_password("password")
user.save()
self.user = user
Expand Down Expand Up @@ -931,7 +931,7 @@ def test_change_password_invalid_old(self):

class RegerateBackupCodesTestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="backupcode-test", email="backupcode-test@example.org")
user = Member(username="backupcode-test", email="backupcode-test@example.org")
user.set_password("password")
user.save()
TOTPDevice(user=user, verified=True).save()
Expand All @@ -955,17 +955,17 @@ def test_regenerate_backup_codes_unique(self):
self.assertFalse(set(first_response.data["d"]["backup_codes"]) & set(second_response.data["d"]["backup_codes"]))

def test_regenerate_backup_codes_no_2fa(self):
user = get_user_model().objects.get(id=self.user.id)
user = Member.objects.get(id=self.user.id)
user.totp_device.delete()
user.save()
self.client.force_authenticate(user=get_user_model().objects.get(id=self.user.id))
self.client.force_authenticate(user=Member.objects.get(id=self.user.id))
response = self.client.post(reverse("regenerate-backup-codes"))
self.assertEqual(response.status_code, HTTP_403_FORBIDDEN)


class CreateBotUserTestCase(APITestCase):
def setUp(self):
user = get_user_model()(username="bot-test", email="bot-test@example.org", is_staff=True, is_superuser=True)
user = Member(username="bot-test", email="bot-test@example.org", is_staff=True, is_superuser=True)
user.set_password("password")
user.save()
self.user = user
Expand Down
Loading