Skip to content

Commit

Permalink
refactor: move project_id query param filtering to get_queryset o…
Browse files Browse the repository at this point in the history
…n relevant ViewSets
  • Loading branch information
annehaley committed Sep 30, 2024
1 parent 2e66d31 commit 305f0b2
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 23 deletions.
18 changes: 2 additions & 16 deletions uvdat/core/rest/access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,13 @@ def has_object_permission(self, request, view, obj):

class GuardianFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
project_id = request.query_params.get('project')
if project_id is not None:
try:
project_id = int(project_id)
except ValueError:
project_id = None

if request.user.is_superuser and project_id is None:
if request.user.is_superuser:
return queryset

# Allow user to have any level of permission
all_perms = [x for x, _ in Project._meta.permissions]
user_projects = get_objects_for_user(
klass=(
models.Project
if project_id is None
else models.Project.objects.filter(id=project_id)
),
user=request.user,
perms=all_perms,
any_perm=True,
klass=models.Project, user=request.user, perms=all_perms, any_perm=True
)

# Return queryset filtered by objects that are within these projects
Expand Down
8 changes: 8 additions & 0 deletions uvdat/core/rest/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ class ChartViewSet(ModelViewSet):
filter_backends = [GuardianFilter]
lookup_field = 'id'

def get_queryset(self):
qs = super().get_queryset()
project_id: str = self.request.query_params.get('project')
if project_id is None or not project_id.isdigit():
return qs

return qs.filter(project=int(project_id))

def validate_editable(self, chart, func, *args, **kwargs):
if chart.editable:
return func(*args, **kwargs)
Expand Down
8 changes: 8 additions & 0 deletions uvdat/core/rest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ class DatasetViewSet(ModelViewSet):
filter_backends = [GuardianFilter]
lookup_field = 'id'

def get_queryset(self):
qs = super().get_queryset()
project_id: str = self.request.query_params.get('project')
if project_id is None or not project_id.isdigit():
return qs

return qs.filter(project=int(project_id))

@action(detail=True, methods=['get'])
def map_layers(self, request, **kwargs):
dataset: Dataset = self.get_object()
Expand Down
15 changes: 8 additions & 7 deletions uvdat/core/rest/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,20 @@ class DerivedRegionViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, Gen
filter_backends = [GuardianFilter]
lookup_field = 'id'

def get_queryset(self):
qs = super().get_queryset()
project_id: str = self.request.query_params.get('project')
if project_id is None or not project_id.isdigit():
return qs

return qs.filter(project=int(project_id))

def get_serializer_class(self):
if self.detail:
return DerivedRegionDetailSerializer

return super().get_serializer_class()

def get_queryset(self):
project_id = self.request.query_params.get('project')
if project_id:
return DerivedRegion.objects.filter(project__id=project_id)
else:
return DerivedRegion.objects.all()

@action(detail=True, methods=['GET'])
def as_feature(self, request, *args, **kwargs):
obj: DerivedRegion = self.get_object()
Expand Down

0 comments on commit 305f0b2

Please sign in to comment.