From 6943cb4696cfb4de250fd55e924448d742b0aa33 Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Wed, 14 Aug 2024 13:37:05 +0000 Subject: [PATCH] fix: address some server-side API bugs --- uvdat/core/rest/chart.py | 1 + uvdat/core/rest/dataset.py | 1 + uvdat/core/rest/file_item.py | 1 + uvdat/core/rest/filter.py | 15 ++++++--------- uvdat/core/rest/map_layers.py | 2 ++ uvdat/core/rest/network.py | 3 +++ uvdat/core/rest/project.py | 4 ++++ uvdat/core/rest/regions.py | 2 ++ uvdat/core/rest/serializers.py | 14 +++++++++++--- uvdat/core/rest/simulations.py | 1 + 10 files changed, 32 insertions(+), 12 deletions(-) diff --git a/uvdat/core/rest/chart.py b/uvdat/core/rest/chart.py index 20f6818a..28b18f20 100644 --- a/uvdat/core/rest/chart.py +++ b/uvdat/core/rest/chart.py @@ -11,6 +11,7 @@ class ChartViewSet(ModelViewSet): queryset = Chart.objects.all() serializer_class = ChartSerializer filter_backends = [AccessControl] + lookup_field = "id" def validate_editable(self, chart, func, *args, **kwargs): if chart.editable: diff --git a/uvdat/core/rest/dataset.py b/uvdat/core/rest/dataset.py index 1e6f88c9..bd2559e2 100644 --- a/uvdat/core/rest/dataset.py +++ b/uvdat/core/rest/dataset.py @@ -21,6 +21,7 @@ class DatasetViewSet(ModelViewSet): queryset = Dataset.objects.all() serializer_class = DatasetSerializer filter_backends = [AccessControl] + lookup_field = "id" @action(detail=True, methods=['get']) def map_layers(self, request, **kwargs): diff --git a/uvdat/core/rest/file_item.py b/uvdat/core/rest/file_item.py index 8233a8c4..a0412a23 100644 --- a/uvdat/core/rest/file_item.py +++ b/uvdat/core/rest/file_item.py @@ -9,3 +9,4 @@ class FileItemViewSet(ModelViewSet): queryset = FileItem.objects.all() serializer_class = FileItemSerializer filter_backends = [AccessControl] + lookup_field = "id" diff --git a/uvdat/core/rest/filter.py b/uvdat/core/rest/filter.py index 1e179a8e..0cf5cfe8 100644 --- a/uvdat/core/rest/filter.py +++ b/uvdat/core/rest/filter.py @@ -1,23 +1,20 @@ from rest_framework.filters import BaseFilterBackend -from uvdat.core.models import Project - class AccessControl(BaseFilterBackend): def filter_queryset(self, request, queryset, view): project_id = request.query_params.get('project') user = request.user + valid_ids = [o.id for o in queryset] if project_id: - project = Project.objects.get(id=project_id) - queryset = [o for o in queryset if o.is_in_project(project)] + valid_ids = [o.id for o in queryset if o.is_in_project(project_id)] if request.method == 'GET': - queryset = [o for o in queryset if o.readable_by(user)] + valid_ids = [o.id for o in queryset if o.readable_by(user) and o.id in valid_ids] elif request.method == 'PUT' or request.method == 'PATCH': - queryset = [o for o in queryset if o.editable_by(user)] + valid_ids = [o.id for o in queryset if o.editable_by(user) and o.id in valid_ids] elif request.method == 'DELETE': - queryset = [o for o in queryset if o.deletable_by(user)] + valid_ids = [o.id for o in queryset if o.deletable_by(user) and o.id in valid_ids] elif request.method == 'POST': # no access control required for POST requests pass - - return queryset + return queryset.filter(id__in=valid_ids) diff --git a/uvdat/core/rest/map_layers.py b/uvdat/core/rest/map_layers.py index 97cc0a4f..f609a4e4 100644 --- a/uvdat/core/rest/map_layers.py +++ b/uvdat/core/rest/map_layers.py @@ -74,6 +74,7 @@ class RasterMapLayerViewSet(ModelViewSet, LargeImageFileDetailMixin): queryset = RasterMapLayer.objects.select_related('dataset').all() serializer_class = RasterMapLayerSerializer filter_backends = [AccessControl] + lookup_field = "id" FILE_FIELD_NAME = 'cloud_optimized_geotiff' @action( @@ -92,6 +93,7 @@ class VectorMapLayerViewSet(ModelViewSet): queryset = VectorMapLayer.objects.select_related('dataset').all() serializer_class = VectorMapLayerSerializer filter_backends = [AccessControl] + lookup_field = "id" def retrieve(self, request, *args, **kwargs): instance = self.get_object() diff --git a/uvdat/core/rest/network.py b/uvdat/core/rest/network.py index d72b5abf..b449f766 100644 --- a/uvdat/core/rest/network.py +++ b/uvdat/core/rest/network.py @@ -13,15 +13,18 @@ class NetworkViewSet(ModelViewSet): queryset = Network.objects.all() serializer_class = NetworkSerializer filter_backends = [AccessControl] + lookup_field = "id" class NetworkNodeViewSet(ModelViewSet): queryset = NetworkNode.objects.all() serializer_class = NetworkNodeSerializer filter_backends = [AccessControl] + lookup_field = "id" class NetworkEdgeViewSet(ModelViewSet): queryset = NetworkEdge.objects.all() serializer_class = NetworkEdgeSerializer filter_backends = [AccessControl] + lookup_field = "id" diff --git a/uvdat/core/rest/project.py b/uvdat/core/rest/project.py index 902b3407..4f64b74f 100644 --- a/uvdat/core/rest/project.py +++ b/uvdat/core/rest/project.py @@ -12,6 +12,10 @@ class ProjectViewSet(ModelViewSet): queryset = Project.objects.all() serializer_class = ProjectSerializer filter_backends = [AccessControl] + lookup_field = "id" + + def perform_create(self, serializer): + serializer.save(owner=self.request.user) @action(detail=True, methods=['get']) def regions(self, request, **kwargs): diff --git a/uvdat/core/rest/regions.py b/uvdat/core/rest/regions.py index e6d00e68..14bedf28 100644 --- a/uvdat/core/rest/regions.py +++ b/uvdat/core/rest/regions.py @@ -21,12 +21,14 @@ class SourceRegionViewSet(ModelViewSet): queryset = SourceRegion.objects.all() serializer_class = SourceRegionSerializer filter_backends = [AccessControl] + lookup_field = "id" class DerivedRegionViewSet(ModelViewSet): queryset = DerivedRegion.objects.all() serializer_class = DerivedRegionListSerializer filter_backends = [AccessControl] + lookup_field = "id" def get_serializer_class(self): if self.detail: diff --git a/uvdat/core/rest/serializers.py b/uvdat/core/rest/serializers.py index 330ab426..c4786f6b 100644 --- a/uvdat/core/rest/serializers.py +++ b/uvdat/core/rest/serializers.py @@ -1,6 +1,7 @@ import json from django.contrib.auth.models import User +from django.contrib.gis.geos import Point from django.contrib.gis.serializers import geojson from rest_framework import serializers @@ -28,15 +29,22 @@ class Meta: class ProjectSerializer(serializers.ModelSerializer): default_map_center = serializers.SerializerMethodField('get_center') - owner = UserSerializer(allow_null=True) - collaborators = UserSerializer(many=True) - followers = UserSerializer(many=True) + owner = UserSerializer(allow_null=True, required=False) + collaborators = UserSerializer(many=True, required=False) + followers = UserSerializer(many=True, required=False) def get_center(self, obj): # Web client expects Lon, Lat if obj.default_map_center: return [obj.default_map_center.y, obj.default_map_center.x] + def to_internal_value(self, data): + center = data.get('default_map_center') + data = super().to_internal_value(data) + if isinstance(center, list): + data['default_map_center'] = Point(center[1], center[0]) + return data + class Meta: model = Project fields = '__all__' diff --git a/uvdat/core/rest/simulations.py b/uvdat/core/rest/simulations.py index 710039bb..864c4c46 100644 --- a/uvdat/core/rest/simulations.py +++ b/uvdat/core/rest/simulations.py @@ -65,6 +65,7 @@ class SimulationViewSet(ModelViewSet): queryset = SimulationResult.objects.all() serializer_class = uvdat_serializers.SimulationResultSerializer filter_backends = [AccessControl] + lookup_field = "id" @action( detail=False,