Skip to content

Commit

Permalink
core: groups: optimize recursive children query (#9931)
Browse files Browse the repository at this point in the history
  • Loading branch information
rissson authored Jun 3, 2024
1 parent 562c52a commit a989390
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 38 deletions.
80 changes: 46 additions & 34 deletions authentik/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django_cte import CTEQuerySet, With
from guardian.conf import settings
from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager
Expand Down Expand Up @@ -56,6 +57,8 @@
"authentik_used_by_shadows",
)

GROUP_RECURSION_LIMIT = 20


def default_token_duration() -> datetime:
"""Default duration a Token is valid"""
Expand Down Expand Up @@ -96,6 +99,40 @@ class UserTypes(models.TextChoices):
INTERNAL_SERVICE_ACCOUNT = "internal_service_account"


class GroupQuerySet(CTEQuerySet):
def with_children_recursive(self):
"""Recursively get all groups that have the current queryset as parents
or are indirectly related."""

def make_cte(cte):
"""Build the query that ends up in WITH RECURSIVE"""
# Start from self, aka the current query
# Add a depth attribute to limit the recursion
return self.annotate(
relative_depth=models.Value(0, output_field=models.IntegerField())
).union(
# Here is the recursive part of the query. cte refers to the previous iteration
# Only select groups for which the parent is part of the previous iteration
# and increase the depth
# Finally, limit the depth
cte.join(Group, group_uuid=cte.col.parent_id)
.annotate(
relative_depth=models.ExpressionWrapper(
cte.col.relative_depth
+ models.Value(1, output_field=models.IntegerField()),
output_field=models.IntegerField(),
)
)
.filter(relative_depth__lt=GROUP_RECURSION_LIMIT),
all=True,
)

# Build the recursive query, see above
cte = With.recursive(make_cte)
# Return the result, as a usable queryset for Group.
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)


class Group(SerializerModel):
"""Group model which supports a basic hierarchy and has attributes"""

Expand All @@ -118,6 +155,8 @@ class Group(SerializerModel):
)
attributes = models.JSONField(default=dict, blank=True)

objects = GroupQuerySet.as_manager()

@property
def serializer(self) -> Serializer:
from authentik.core.api.groups import GroupSerializer
Expand All @@ -136,36 +175,11 @@ def is_member(self, user: "User") -> bool:
return user.all_groups().filter(group_uuid=self.group_uuid).exists()

def children_recursive(self: Self | QuerySet["Group"]) -> QuerySet["Group"]:
"""Recursively get all groups that have this as parent or are indirectly related"""
direct_groups = []
if isinstance(self, QuerySet):
direct_groups = list(x for x in self.all().values_list("pk", flat=True).iterator())
else:
direct_groups = [self.pk]
if len(direct_groups) < 1:
return Group.objects.none()
query = """
WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group
WHERE authentik_core_group.group_uuid = ANY(%s)
UNION ALL
SELECT authentik_core_group.*, parents.relative_depth + 1
FROM authentik_core_group, parents
WHERE (
authentik_core_group.group_uuid = parents.parent_id and
parents.relative_depth < 20
)
)
SELECT group_uuid
FROM parents
GROUP BY group_uuid, name
ORDER BY name;
"""
group_pks = [group.pk for group in Group.objects.raw(query, [direct_groups]).iterator()]
return Group.objects.filter(pk__in=group_pks)
"""Compatibility layer for Group.objects.with_children_recursive()"""
qs = self
if not isinstance(self, QuerySet):
qs = Group.objects.filter(group_uuid=self.group_uuid)
return qs.with_children_recursive()

def __str__(self):
return f"Group {self.name}"
Expand Down Expand Up @@ -232,10 +246,8 @@ def default_path() -> str:
return User._meta.get_field("path").default

def all_groups(self) -> QuerySet[Group]:
"""Recursively get all groups this user is a member of.
At least one query is done to get the direct groups of the user, with groups
there are at most 3 queries done"""
return Group.children_recursive(self.ak_groups.all())
"""Recursively get all groups this user is a member of."""
return self.ak_groups.all().with_children_recursive()

def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]:
"""Get a dictionary containing the attributes from all groups the user belongs to,
Expand Down
5 changes: 3 additions & 2 deletions authentik/rbac/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ def rbac_group_role_m2m(
if action not in ["post_add", "post_remove", "post_clear"]:
return
with atomic():
group_users = list(
instance.children_recursive()
group_users = (
Group.objects.filter(group_uuid=instance.group_uuid)
.with_children_recursive()
.exclude(users__isnull=True)
.values_list("users", flat=True)
)
Expand Down
15 changes: 13 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ dacite = "*"
deepmerge = "*"
defusedxml = "*"
django = "*"
django-cte = "*"
django-filter = "*"
django-guardian = "*"
django-model-utils = "*"
Expand Down

0 comments on commit a989390

Please sign in to comment.