Skip to content

Commit

Permalink
refactor: update guardian.py
Browse files Browse the repository at this point in the history
  • Loading branch information
annehaley committed Sep 30, 2024
1 parent 61f9767 commit 9989e6f
Showing 1 changed file with 35 additions and 72 deletions.
107 changes: 35 additions & 72 deletions uvdat/core/rest/guardian.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from django.contrib.auth.models import User
from django.db.models import Model
from django.db.models.query import QuerySet
from guardian.shortcuts import get_objects_for_user
from rest_framework.filters import BaseFilterBackend
from rest_framework.permissions import SAFE_METHODS, IsAuthenticated
Expand All @@ -6,94 +9,54 @@


# TODO: Dataset permissions should be separated from Project permissions
def filter_by_parent_perms(klass, user, perms=None, queryset=None, obj=None):
objects = None
if queryset is not None:
objects = queryset
elif obj is not None:
objects = klass.objects.filter(id=obj.id)
def filter_queryset_by_project_permission(
queryset: QuerySet[Model], user: User, perms: list[str] | None = None
):
if perms is None:
perms = ['follower', 'collaborator', 'owner']

parent_queryset = None
filter_function = None
if klass == models.Project:
parent_queryset = objects
filter_function = lambda parents: parents
elif klass == models.Chart:
parent_queryset = models.Project.objects.filter(charts__in=objects).distinct()
filter_function = lambda parents: objects.filter(project__in=parents).distinct()
elif klass == models.Dataset:
parent_queryset = models.Project.objects.filter(datasets__in=objects).distinct()
filter_function = lambda parents: objects.filter(project__in=parents).distinct()
elif klass == models.FileItem:
parent_queryset = models.Project.objects.filter(
datasets__source_files__in=objects
).distinct()
filter_function = lambda parents: objects.filter(dataset__project__in=parents).distinct()
elif klass == models.RasterMapLayer:
parent_queryset = models.Project.objects.filter(
datasets__rastermaplayer__in=objects
).distinct()
filter_function = lambda parents: objects.filter(dataset__project__in=parents).distinct()
elif klass == models.VectorMapLayer:
parent_queryset = models.Project.objects.filter(
datasets__vectormaplayer__in=objects
).distinct()
filter_function = lambda parents: objects.filter(dataset__project__in=parents).distinct()
# TODO: Add clause for VectorFeature when an API Viewset is added for that model
elif klass == models.SourceRegion:
parent_queryset = models.Project.objects.filter(datasets__regions__in=objects).distinct()
filter_function = lambda parents: objects.filter(dataset__project__in=parents).distinct()
elif klass == models.DerivedRegion:
parent_queryset = models.Project.objects.filter(derived_regions__in=objects).distinct()
filter_function = lambda parents: objects.filter(project__in=parents).distinct()
elif klass == models.Network:
parent_queryset = models.Project.objects.filter(datasets__networks__in=objects).distinct()
filter_function = lambda parents: objects.filter(dataset__project__in=parents).distinct()
elif klass == models.NetworkEdge:
parent_queryset = models.Project.objects.filter(
datasets__networks__edges__in=objects
).distinct()
filter_function = lambda parents: objects.filter(
network__dataset__project__in=parents
).distinct()
elif klass == models.NetworkNode:
parent_queryset = models.Project.objects.filter(
datasets__networks__nodes__in=objects
).distinct()
filter_function = lambda parents: objects.filter(
network__dataset__project__in=parents
).distinct()

if parent_queryset is not None and filter_function is not None:
allowed_parents = get_objects_for_user(
klass=parent_queryset,
user=user,
perms=perms,
any_perm=True,
)
allowed_children = filter_function(allowed_parents)
return allowed_children
return klass.objects.none()
# Get all projects a user has access to
user_projects = get_objects_for_user(
klass=models.Project, user=user, perms=perms, any_perm=True
)
model = queryset.model
if model == models.Project:
return queryset.filter(id__in=user_projects.values_list('id', flat=True))
if model in [models.Dataset, models.Chart, models.DerivedRegion]:
return queryset.filter(project__in=user_projects)
if model in [
models.FileItem,
models.RasterMapLayer,
models.VectorMapLayer,
models.Network,
models.SourceRegion,
]:
return queryset.filter(dataset__project__in=user_projects)
if model in [models.NetworkNode, models.NetworkEdge]:
return queryset.filter(network__dataset__project__in=user_projects)
# If any models are un-caught, raise an exception
raise NotImplementedError


class GuardianPermission(IsAuthenticated):
def has_object_permission(self, request, view, obj):
if request.user.is_superuser:
return True

perms = ['follower', 'collaborator', 'owner']
if request.method not in SAFE_METHODS:
perms = ['collaborator', 'owner']
if request.method == 'DELETE':
perms = ['owner']
allowed_objects = filter_by_parent_perms(obj.__class__, request.user, perms=perms, obj=obj)
return allowed_objects.filter(id=obj.id).exists()
if not isinstance(obj, Model):
raise Exception('Only Django models may be used in permission check')
# Create queryset out of single object, so it can be passed to the filter function
queryset = obj.__class__.objects.filter(pk=obj.pk)
# If the object remains in the queryset after this function filters it, then the user has
# the required permission on at least one associated project
return filter_queryset_by_project_permission(queryset, request.user, perms).exists()


class GuardianFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
if request.user.is_superuser:
return queryset
return filter_by_parent_perms(queryset.model, request.user, queryset=queryset)
return filter_queryset_by_project_permission(queryset, request.user)

0 comments on commit 9989e6f

Please sign in to comment.