From 305f0b2d6d94a985c46579799257a3e9c115f51d Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Mon, 30 Sep 2024 23:25:23 +0000 Subject: [PATCH] refactor: move `project_id` query param filtering to `get_queryset` on relevant ViewSets --- uvdat/core/rest/access_control.py | 18 ++---------------- uvdat/core/rest/chart.py | 8 ++++++++ uvdat/core/rest/dataset.py | 8 ++++++++ uvdat/core/rest/regions.py | 15 ++++++++------- 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/uvdat/core/rest/access_control.py b/uvdat/core/rest/access_control.py index dbc52970..684d2394 100644 --- a/uvdat/core/rest/access_control.py +++ b/uvdat/core/rest/access_control.py @@ -57,27 +57,13 @@ def has_object_permission(self, request, view, obj): class GuardianFilter(BaseFilterBackend): def filter_queryset(self, request, queryset, view): - project_id = request.query_params.get('project') - if project_id is not None: - try: - project_id = int(project_id) - except ValueError: - project_id = None - - if request.user.is_superuser and project_id is None: + if request.user.is_superuser: return queryset # Allow user to have any level of permission all_perms = [x for x, _ in Project._meta.permissions] user_projects = get_objects_for_user( - klass=( - models.Project - if project_id is None - else models.Project.objects.filter(id=project_id) - ), - user=request.user, - perms=all_perms, - any_perm=True, + klass=models.Project, user=request.user, perms=all_perms, any_perm=True ) # Return queryset filtered by objects that are within these projects diff --git a/uvdat/core/rest/chart.py b/uvdat/core/rest/chart.py index 12cf5c81..12822e28 100644 --- a/uvdat/core/rest/chart.py +++ b/uvdat/core/rest/chart.py @@ -14,6 +14,14 @@ class ChartViewSet(ModelViewSet): filter_backends = [GuardianFilter] lookup_field = 'id' + def get_queryset(self): + qs = super().get_queryset() + project_id: str = self.request.query_params.get('project') + if project_id is None or not project_id.isdigit(): + return qs + + return qs.filter(project=int(project_id)) + def validate_editable(self, chart, func, *args, **kwargs): if chart.editable: return func(*args, **kwargs) diff --git a/uvdat/core/rest/dataset.py b/uvdat/core/rest/dataset.py index 24b24402..1ef7eaf9 100644 --- a/uvdat/core/rest/dataset.py +++ b/uvdat/core/rest/dataset.py @@ -24,6 +24,14 @@ class DatasetViewSet(ModelViewSet): filter_backends = [GuardianFilter] lookup_field = 'id' + def get_queryset(self): + qs = super().get_queryset() + project_id: str = self.request.query_params.get('project') + if project_id is None or not project_id.isdigit(): + return qs + + return qs.filter(project=int(project_id)) + @action(detail=True, methods=['get']) def map_layers(self, request, **kwargs): dataset: Dataset = self.get_object() diff --git a/uvdat/core/rest/regions.py b/uvdat/core/rest/regions.py index 00781227..d3a6496c 100644 --- a/uvdat/core/rest/regions.py +++ b/uvdat/core/rest/regions.py @@ -32,19 +32,20 @@ class DerivedRegionViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, Gen filter_backends = [GuardianFilter] lookup_field = 'id' + def get_queryset(self): + qs = super().get_queryset() + project_id: str = self.request.query_params.get('project') + if project_id is None or not project_id.isdigit(): + return qs + + return qs.filter(project=int(project_id)) + def get_serializer_class(self): if self.detail: return DerivedRegionDetailSerializer return super().get_serializer_class() - def get_queryset(self): - project_id = self.request.query_params.get('project') - if project_id: - return DerivedRegion.objects.filter(project__id=project_id) - else: - return DerivedRegion.objects.all() - @action(detail=True, methods=['GET']) def as_feature(self, request, *args, **kwargs): obj: DerivedRegion = self.get_object()