Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the number of SQL queries in updates of groups #255

Merged
merged 4 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions django_auth_adfs/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.models import Group
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied
from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist,
PermissionDenied)

from django_auth_adfs import signals
from django_auth_adfs.config import provider_config, settings
Expand Down Expand Up @@ -322,27 +323,27 @@ def update_user_groups(self, user, claim_groups):
"""
if settings.GROUPS_CLAIM is not None:
# Update the user's group memberships
django_groups = [group.name for group in user.groups.all()]
user_group_names = user.groups.all().values_list("name", flat=True)

if sorted(claim_groups) != sorted(user_group_names):
# Get the list of already existing groups in one SQL query
existing_claimed_groups = Group.objects.filter(name__in=claim_groups)

if sorted(claim_groups) != sorted(django_groups):
existing_groups = list(Group.objects.filter(name__in=claim_groups).iterator())
existing_group_names = frozenset(group.name for group in existing_groups)
new_groups = []
if settings.MIRROR_GROUPS:
new_groups = [
existing_claimed_group_names = (
group.name for group in existing_claimed_groups
)
# One SQL query by created group.
# bulk_create could have been used here but we want to send signals.
new_claimed_groups = [
Group.objects.get_or_create(name=name)[0]
for name in claim_groups
if name not in existing_group_names
for name in claim_groups if name not in existing_claimed_group_names
]
# Associate the users to all claimed groups
user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups))
else:
for name in claim_groups:
if name not in existing_group_names:
try:
group = Group.objects.get(name=name)
new_groups.append(group)
except ObjectDoesNotExist:
pass
user.groups.set(existing_groups + new_groups)
# Associate the user to only existing claimed groups
user.groups.set(existing_claimed_groups)

def update_user_flags(self, user, claims, claim_groups):
"""
Expand Down
39 changes: 36 additions & 3 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
from django_auth_adfs.exceptions import MFARequired

try:
from urllib.parse import urlparse, parse_qs
from urllib.parse import parse_qs, urlparse
except ImportError: # Python 2.7
from urlparse import urlparse, parse_qs

from copy import deepcopy

from django.contrib.auth.models import User, Group
from django.contrib.auth.models import Group, User
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
from django.db.models.signals import post_save
from django.test import TestCase, RequestFactory
from django.test import RequestFactory, TestCase
from mock import Mock, patch

from django_auth_adfs import signals
from django_auth_adfs.backend import AdfsAuthCodeBackend
from django_auth_adfs.config import ProviderConfig, Settings

from .models import Profile
from .utils import mock_adfs

Expand Down Expand Up @@ -175,6 +176,38 @@ def test_no_group_claim(self):
self.assertEqual(user.email, "john.doe@example.com")
self.assertEqual(len(user.groups.all()), 0)

@mock_adfs("2016")
def test_group_claim_with_mirror_groups(self):
# Remove one group
Group.objects.filter(name="group1").delete()

backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True):
user = backend.authenticate(self.request, authorization_code="dummycode")
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
self.assertEqual(user.email, "john.doe@example.com")
# group1 is restored
group_names = user.groups.order_by("name").values_list("name", flat=True)
self.assertSequenceEqual(group_names, ['group1', 'group2'])

@mock_adfs("2016")
def test_group_claim_without_mirror_groups(self):
# Remove one group
Group.objects.filter(name="group1").delete()

backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False):
user = backend.authenticate(self.request, authorization_code="dummycode")
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
self.assertEqual(user.email, "john.doe@example.com")
# User is not added to group1 because the group doesn't exist
group_names = user.groups.values_list("name", flat=True)
self.assertSequenceEqual(group_names, ['group2'])

@mock_adfs("2016", empty_keys=True)
def test_empty_keys(self):
backend = AdfsAuthCodeBackend()
Expand Down