Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the number of SQL queries in updates of groups #255

Merged
merged 4 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions django_auth_adfs/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.models import Group
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied
from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist,
PermissionDenied)

from django_auth_adfs import signals
from django_auth_adfs.config import provider_config, settings
Expand Down Expand Up @@ -322,27 +323,24 @@ def update_user_groups(self, user, claim_groups):
"""
if settings.GROUPS_CLAIM is not None:
# Update the user's group memberships
django_groups = [group.name for group in user.groups.all()]
user_group_names = user.groups.all().values_list("name", flat=True)

if sorted(claim_groups) != sorted(django_groups):
existing_groups = list(Group.objects.filter(name__in=claim_groups).iterator())
existing_group_names = frozenset(group.name for group in existing_groups)
new_groups = []
if sorted(claim_groups) != sorted(user_group_names):
# Get the list of already existing groups in one query
existing_claimed_groups = Group.objects.filter(name__in=claim_groups)
existing_claimed_group_names = (
group.name for group in existing_claimed_groups
)

new_claimed_group_names = (name for name in claim_groups if name not in existing_claimed_group_names)
if settings.MIRROR_GROUPS:
new_groups = [
new_claimed_groups = [
Group.objects.get_or_create(name=name)[0]
for name in claim_groups
if name not in existing_group_names
for name in new_claimed_group_names
]
else:
for name in claim_groups:
if name not in existing_group_names:
try:
group = Group.objects.get(name=name)
new_groups.append(group)
except ObjectDoesNotExist:
pass
user.groups.set(existing_groups + new_groups)
new_claimed_groups = Group.objects.filter(name__in=new_claimed_group_names)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When settings.MIRROR_GROUPS is false, the only groups that should be set for the user are existing_claimed_groups. You can likely move the list comprehension of 335 into the if settings.MIRROR_GROUPS branch and define new_claimed_groups as an empty list in the else branch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, indeed!
It does't make sense to filter again on new_claimed_group_names because the result won't change.
It avoids one more query, great!

I prefer to move user.groups.set() in the if branches instead of define new_claimed_groups as an empty list.

user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups))

def update_user_flags(self, user, claims, claim_groups):
"""
Expand Down
23 changes: 20 additions & 3 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
from django_auth_adfs.exceptions import MFARequired

try:
from urllib.parse import urlparse, parse_qs
from urllib.parse import parse_qs, urlparse
except ImportError: # Python 2.7
from urlparse import urlparse, parse_qs

from copy import deepcopy

from django.contrib.auth.models import User, Group
from django.contrib.auth.models import Group, User
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
from django.db.models.signals import post_save
from django.test import TestCase, RequestFactory
from django.test import RequestFactory, TestCase
from mock import Mock, patch

from django_auth_adfs import signals
from django_auth_adfs.backend import AdfsAuthCodeBackend
from django_auth_adfs.config import ProviderConfig, Settings

from .models import Profile
from .utils import mock_adfs

Expand Down Expand Up @@ -175,6 +176,22 @@ def test_no_group_claim(self):
self.assertEqual(user.email, "john.doe@example.com")
self.assertEqual(len(user.groups.all()), 0)

@mock_adfs("2016")
def test_group_claim_with_mirror_groups(self):
# Remove one group
Group.objects.filter(name="group1").delete()

backend = AdfsAuthCodeBackend()
with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", "group"):
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True):
user = backend.authenticate(self.request, authorization_code="dummycode")
self.assertIsInstance(user, User)
self.assertEqual(user.first_name, "John")
self.assertEqual(user.last_name, "Doe")
self.assertEqual(user.email, "john.doe@example.com")
# Group restored
self.assertEqual(len(user.groups.all()), 2)

@mock_adfs("2016", empty_keys=True)
def test_empty_keys(self):
backend = AdfsAuthCodeBackend()
Expand Down