From 1754ac318a2c2a025ed075af9df5cfc6ad72b9d4 Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Mon, 12 Aug 2024 15:20:06 +0000 Subject: [PATCH] feat: Update API to use AccessControl filter backend --- uvdat/core/rest/chart.py | 15 +++++---------- uvdat/core/rest/dataset.py | 28 +++++++++++++++------------- uvdat/core/rest/file_item.py | 2 ++ uvdat/core/rest/filter.py | 23 +++++++++++++++++++++++ uvdat/core/rest/map_layers.py | 3 +++ uvdat/core/rest/network.py | 4 ++++ uvdat/core/rest/project.py | 2 ++ uvdat/core/rest/regions.py | 9 ++++++--- uvdat/core/rest/simulations.py | 7 +++++-- 9 files changed, 65 insertions(+), 28 deletions(-) create mode 100644 uvdat/core/rest/filter.py diff --git a/uvdat/core/rest/chart.py b/uvdat/core/rest/chart.py index 55371d6a..20f6818a 100644 --- a/uvdat/core/rest/chart.py +++ b/uvdat/core/rest/chart.py @@ -1,21 +1,16 @@ from django.http import HttpResponse from rest_framework.decorators import action -from rest_framework.viewsets import GenericViewSet, mixins +from rest_framework.viewsets import ModelViewSet from uvdat.core.models import Chart +from uvdat.core.rest.serializers import ChartSerializer +from uvdat.core.rest.filter import AccessControl -from .serializers import ChartSerializer - -class ChartViewSet(GenericViewSet, mixins.ListModelMixin): +class ChartViewSet(ModelViewSet): queryset = Chart.objects.all() serializer_class = ChartSerializer - - def get_queryset(self, **kwargs): - context_id = self.request.query_params.get('context') - if context_id: - return Chart.objects.filter(context__id=context_id) - return Chart.objects.all() + filter_backends = [AccessControl] def validate_editable(self, chart, func, *args, **kwargs): if chart.editable: diff --git a/uvdat/core/rest/dataset.py b/uvdat/core/rest/dataset.py index 85279d5d..1e6f88c9 100644 --- a/uvdat/core/rest/dataset.py +++ b/uvdat/core/rest/dataset.py @@ -6,19 +6,21 @@ from rest_framework.viewsets import ModelViewSet from uvdat.core.models import Dataset, NetworkEdge, NetworkNode -from uvdat.core.rest import serializers as uvdat_serializers +from uvdat.core.rest.serializers import ( + DatasetSerializer, + RasterMapLayerSerializer, + VectorMapLayerSerializer, + NetworkEdgeSerializer, + NetworkNodeSerializer, +) +from uvdat.core.rest.filter import AccessControl from uvdat.core.tasks.chart import add_gcc_chart_datum class DatasetViewSet(ModelViewSet): - serializer_class = uvdat_serializers.DatasetSerializer - - def get_queryset(self): - project_id = self.request.query_params.get('project') - if project_id: - return Dataset.objects.filter(project__id=project_id) - else: - return Dataset.objects.all() + queryset = Dataset.objects.all() + serializer_class = DatasetSerializer + filter_backends = [AccessControl] @action(detail=True, methods=['get']) def map_layers(self, request, **kwargs): @@ -27,10 +29,10 @@ def map_layers(self, request, **kwargs): # Set serializer based on dataset type if dataset.dataset_type == Dataset.DatasetType.RASTER: - serializer = uvdat_serializers.RasterMapLayerSerializer(map_layers, many=True) + serializer = RasterMapLayerSerializer(map_layers, many=True) elif dataset.dataset_type == Dataset.DatasetType.VECTOR: # Set serializer - serializer = uvdat_serializers.VectorMapLayerSerializer(map_layers, many=True) + serializer = VectorMapLayerSerializer(map_layers, many=True) else: raise NotImplementedError(f'Dataset Type {dataset.dataset_type}') @@ -51,11 +53,11 @@ def network(self, request, **kwargs): networks.append( { 'nodes': [ - uvdat_serializers.NetworkNodeSerializer(n).data + NetworkNodeSerializer(n).data for n in NetworkNode.objects.filter(network=network) ], 'edges': [ - uvdat_serializers.NetworkEdgeSerializer(e).data + NetworkEdgeSerializer(e).data for e in NetworkEdge.objects.filter(network=network) ], } diff --git a/uvdat/core/rest/file_item.py b/uvdat/core/rest/file_item.py index 1e8e8154..8233a8c4 100644 --- a/uvdat/core/rest/file_item.py +++ b/uvdat/core/rest/file_item.py @@ -2,8 +2,10 @@ from uvdat.core.models import FileItem from uvdat.core.rest.serializers import FileItemSerializer +from uvdat.core.rest.filter import AccessControl class FileItemViewSet(ModelViewSet): queryset = FileItem.objects.all() serializer_class = FileItemSerializer + filter_backends = [AccessControl] diff --git a/uvdat/core/rest/filter.py b/uvdat/core/rest/filter.py new file mode 100644 index 00000000..1e179a8e --- /dev/null +++ b/uvdat/core/rest/filter.py @@ -0,0 +1,23 @@ +from rest_framework.filters import BaseFilterBackend + +from uvdat.core.models import Project + +class AccessControl(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + project_id = request.query_params.get('project') + user = request.user + if project_id: + project = Project.objects.get(id=project_id) + queryset = [o for o in queryset if o.is_in_project(project)] + + if request.method == 'GET': + queryset = [o for o in queryset if o.readable_by(user)] + elif request.method == 'PUT' or request.method == 'PATCH': + queryset = [o for o in queryset if o.editable_by(user)] + elif request.method == 'DELETE': + queryset = [o for o in queryset if o.deletable_by(user)] + elif request.method == 'POST': + # no access control required for POST requests + pass + + return queryset diff --git a/uvdat/core/rest/map_layers.py b/uvdat/core/rest/map_layers.py index fb7e944f..97cc0a4f 100644 --- a/uvdat/core/rest/map_layers.py +++ b/uvdat/core/rest/map_layers.py @@ -8,6 +8,7 @@ from rest_framework.viewsets import ModelViewSet from uvdat.core.models import RasterMapLayer, VectorMapLayer +from uvdat.core.rest.filter import AccessControl from uvdat.core.rest.serializers import ( RasterMapLayerSerializer, VectorMapLayerDetailSerializer, @@ -72,6 +73,7 @@ class RasterMapLayerViewSet(ModelViewSet, LargeImageFileDetailMixin): queryset = RasterMapLayer.objects.select_related('dataset').all() serializer_class = RasterMapLayerSerializer + filter_backends = [AccessControl] FILE_FIELD_NAME = 'cloud_optimized_geotiff' @action( @@ -89,6 +91,7 @@ def get_raster_data(self, request, resolution: str = '1', **kwargs): class VectorMapLayerViewSet(ModelViewSet): queryset = VectorMapLayer.objects.select_related('dataset').all() serializer_class = VectorMapLayerSerializer + filter_backends = [AccessControl] def retrieve(self, request, *args, **kwargs): instance = self.get_object() diff --git a/uvdat/core/rest/network.py b/uvdat/core/rest/network.py index 3e42cbae..d72b5abf 100644 --- a/uvdat/core/rest/network.py +++ b/uvdat/core/rest/network.py @@ -1,6 +1,7 @@ from rest_framework.viewsets import ModelViewSet from uvdat.core.models import Network, NetworkEdge, NetworkNode +from uvdat.core.rest.filter import AccessControl from uvdat.core.rest.serializers import ( NetworkEdgeSerializer, NetworkNodeSerializer, @@ -11,13 +12,16 @@ class NetworkViewSet(ModelViewSet): queryset = Network.objects.all() serializer_class = NetworkSerializer + filter_backends = [AccessControl] class NetworkNodeViewSet(ModelViewSet): queryset = NetworkNode.objects.all() serializer_class = NetworkNodeSerializer + filter_backends = [AccessControl] class NetworkEdgeViewSet(ModelViewSet): queryset = NetworkEdge.objects.all() serializer_class = NetworkEdgeSerializer + filter_backends = [AccessControl] diff --git a/uvdat/core/rest/project.py b/uvdat/core/rest/project.py index 49d5552b..902b3407 100644 --- a/uvdat/core/rest/project.py +++ b/uvdat/core/rest/project.py @@ -5,11 +5,13 @@ from uvdat.core.models import Project from uvdat.core.rest.serializers import ProjectSerializer from uvdat.core.tasks.osmnx import load_roads +from uvdat.core.rest.filter import AccessControl class ProjectViewSet(ModelViewSet): queryset = Project.objects.all() serializer_class = ProjectSerializer + filter_backends = [AccessControl] @action(detail=True, methods=['get']) def regions(self, request, **kwargs): diff --git a/uvdat/core/rest/regions.py b/uvdat/core/rest/regions.py index 98a8f618..e6d00e68 100644 --- a/uvdat/core/rest/regions.py +++ b/uvdat/core/rest/regions.py @@ -3,10 +3,11 @@ from django.http import HttpResponse from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action -from rest_framework.viewsets import GenericViewSet, mixins +from rest_framework.viewsets import ModelViewSet from uvdat.core.models import DerivedRegion, SourceRegion from uvdat.core.tasks.regions import DerivedRegionCreationError, create_derived_region +from uvdat.core.rest.filter import AccessControl from .serializers import ( DerivedRegionCreationSerializer, @@ -16,14 +17,16 @@ ) -class SourceRegionViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet): +class SourceRegionViewSet(ModelViewSet): queryset = SourceRegion.objects.all() serializer_class = SourceRegionSerializer + filter_backends = [AccessControl] -class DerivedRegionViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet): +class DerivedRegionViewSet(ModelViewSet): queryset = DerivedRegion.objects.all() serializer_class = DerivedRegionListSerializer + filter_backends = [AccessControl] def get_serializer_class(self): if self.detail: diff --git a/uvdat/core/rest/simulations.py b/uvdat/core/rest/simulations.py index 11853761..710039bb 100644 --- a/uvdat/core/rest/simulations.py +++ b/uvdat/core/rest/simulations.py @@ -5,10 +5,11 @@ from django.http import HttpResponse from rest_framework.decorators import action from rest_framework.serializers import ModelSerializer -from rest_framework.viewsets import GenericViewSet +from rest_framework.viewsets import ModelViewSet from uvdat.core.models import Project from uvdat.core.models.simulations import AVAILABLE_SIMULATIONS, SimulationResult +from uvdat.core.rest.filter import AccessControl import uvdat.core.rest.serializers as uvdat_serializers @@ -60,8 +61,10 @@ def get_available_simulations(project_id: int): return sims -class SimulationViewSet(GenericViewSet): +class SimulationViewSet(ModelViewSet): + queryset = SimulationResult.objects.all() serializer_class = uvdat_serializers.SimulationResultSerializer + filter_backends = [AccessControl] @action( detail=False,