Skip to content

Commit

Permalink
feat: Update API to use AccessControl filter backend
Browse files Browse the repository at this point in the history
  • Loading branch information
annehaley committed Sep 25, 2024
1 parent 245a04f commit 1754ac3
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 28 deletions.
15 changes: 5 additions & 10 deletions uvdat/core/rest/chart.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
from django.http import HttpResponse
from rest_framework.decorators import action
from rest_framework.viewsets import GenericViewSet, mixins
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import Chart
from uvdat.core.rest.serializers import ChartSerializer
from uvdat.core.rest.filter import AccessControl

from .serializers import ChartSerializer


class ChartViewSet(GenericViewSet, mixins.ListModelMixin):
class ChartViewSet(ModelViewSet):
queryset = Chart.objects.all()
serializer_class = ChartSerializer

def get_queryset(self, **kwargs):
context_id = self.request.query_params.get('context')
if context_id:
return Chart.objects.filter(context__id=context_id)
return Chart.objects.all()
filter_backends = [AccessControl]

def validate_editable(self, chart, func, *args, **kwargs):
if chart.editable:
Expand Down
28 changes: 15 additions & 13 deletions uvdat/core/rest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import Dataset, NetworkEdge, NetworkNode
from uvdat.core.rest import serializers as uvdat_serializers
from uvdat.core.rest.serializers import (
DatasetSerializer,
RasterMapLayerSerializer,
VectorMapLayerSerializer,
NetworkEdgeSerializer,
NetworkNodeSerializer,
)
from uvdat.core.rest.filter import AccessControl
from uvdat.core.tasks.chart import add_gcc_chart_datum


class DatasetViewSet(ModelViewSet):
serializer_class = uvdat_serializers.DatasetSerializer

def get_queryset(self):
project_id = self.request.query_params.get('project')
if project_id:
return Dataset.objects.filter(project__id=project_id)
else:
return Dataset.objects.all()
queryset = Dataset.objects.all()
serializer_class = DatasetSerializer
filter_backends = [AccessControl]

@action(detail=True, methods=['get'])
def map_layers(self, request, **kwargs):
Expand All @@ -27,10 +29,10 @@ def map_layers(self, request, **kwargs):

# Set serializer based on dataset type
if dataset.dataset_type == Dataset.DatasetType.RASTER:
serializer = uvdat_serializers.RasterMapLayerSerializer(map_layers, many=True)
serializer = RasterMapLayerSerializer(map_layers, many=True)
elif dataset.dataset_type == Dataset.DatasetType.VECTOR:
# Set serializer
serializer = uvdat_serializers.VectorMapLayerSerializer(map_layers, many=True)
serializer = VectorMapLayerSerializer(map_layers, many=True)
else:
raise NotImplementedError(f'Dataset Type {dataset.dataset_type}')

Expand All @@ -51,11 +53,11 @@ def network(self, request, **kwargs):
networks.append(
{
'nodes': [
uvdat_serializers.NetworkNodeSerializer(n).data
NetworkNodeSerializer(n).data
for n in NetworkNode.objects.filter(network=network)
],
'edges': [
uvdat_serializers.NetworkEdgeSerializer(e).data
NetworkEdgeSerializer(e).data
for e in NetworkEdge.objects.filter(network=network)
],
}
Expand Down
2 changes: 2 additions & 0 deletions uvdat/core/rest/file_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from uvdat.core.models import FileItem
from uvdat.core.rest.serializers import FileItemSerializer
from uvdat.core.rest.filter import AccessControl


class FileItemViewSet(ModelViewSet):
queryset = FileItem.objects.all()
serializer_class = FileItemSerializer
filter_backends = [AccessControl]
23 changes: 23 additions & 0 deletions uvdat/core/rest/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from rest_framework.filters import BaseFilterBackend

from uvdat.core.models import Project

class AccessControl(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
project_id = request.query_params.get('project')
user = request.user
if project_id:
project = Project.objects.get(id=project_id)
queryset = [o for o in queryset if o.is_in_project(project)]

if request.method == 'GET':
queryset = [o for o in queryset if o.readable_by(user)]
elif request.method == 'PUT' or request.method == 'PATCH':
queryset = [o for o in queryset if o.editable_by(user)]
elif request.method == 'DELETE':
queryset = [o for o in queryset if o.deletable_by(user)]
elif request.method == 'POST':
# no access control required for POST requests
pass

return queryset
3 changes: 3 additions & 0 deletions uvdat/core/rest/map_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import RasterMapLayer, VectorMapLayer
from uvdat.core.rest.filter import AccessControl
from uvdat.core.rest.serializers import (
RasterMapLayerSerializer,
VectorMapLayerDetailSerializer,
Expand Down Expand Up @@ -72,6 +73,7 @@
class RasterMapLayerViewSet(ModelViewSet, LargeImageFileDetailMixin):
queryset = RasterMapLayer.objects.select_related('dataset').all()
serializer_class = RasterMapLayerSerializer
filter_backends = [AccessControl]
FILE_FIELD_NAME = 'cloud_optimized_geotiff'

@action(
Expand All @@ -89,6 +91,7 @@ def get_raster_data(self, request, resolution: str = '1', **kwargs):
class VectorMapLayerViewSet(ModelViewSet):
queryset = VectorMapLayer.objects.select_related('dataset').all()
serializer_class = VectorMapLayerSerializer
filter_backends = [AccessControl]

def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
Expand Down
4 changes: 4 additions & 0 deletions uvdat/core/rest/network.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import Network, NetworkEdge, NetworkNode
from uvdat.core.rest.filter import AccessControl
from uvdat.core.rest.serializers import (
NetworkEdgeSerializer,
NetworkNodeSerializer,
Expand All @@ -11,13 +12,16 @@
class NetworkViewSet(ModelViewSet):
queryset = Network.objects.all()
serializer_class = NetworkSerializer
filter_backends = [AccessControl]


class NetworkNodeViewSet(ModelViewSet):
queryset = NetworkNode.objects.all()
serializer_class = NetworkNodeSerializer
filter_backends = [AccessControl]


class NetworkEdgeViewSet(ModelViewSet):
queryset = NetworkEdge.objects.all()
serializer_class = NetworkEdgeSerializer
filter_backends = [AccessControl]
2 changes: 2 additions & 0 deletions uvdat/core/rest/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from uvdat.core.models import Project
from uvdat.core.rest.serializers import ProjectSerializer
from uvdat.core.tasks.osmnx import load_roads
from uvdat.core.rest.filter import AccessControl


class ProjectViewSet(ModelViewSet):
queryset = Project.objects.all()
serializer_class = ProjectSerializer
filter_backends = [AccessControl]

@action(detail=True, methods=['get'])
def regions(self, request, **kwargs):
Expand Down
9 changes: 6 additions & 3 deletions uvdat/core/rest/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from django.http import HttpResponse
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.viewsets import GenericViewSet, mixins
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import DerivedRegion, SourceRegion
from uvdat.core.tasks.regions import DerivedRegionCreationError, create_derived_region
from uvdat.core.rest.filter import AccessControl

from .serializers import (
DerivedRegionCreationSerializer,
Expand All @@ -16,14 +17,16 @@
)


class SourceRegionViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet):
class SourceRegionViewSet(ModelViewSet):
queryset = SourceRegion.objects.all()
serializer_class = SourceRegionSerializer
filter_backends = [AccessControl]


class DerivedRegionViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet):
class DerivedRegionViewSet(ModelViewSet):
queryset = DerivedRegion.objects.all()
serializer_class = DerivedRegionListSerializer
filter_backends = [AccessControl]

def get_serializer_class(self):
if self.detail:
Expand Down
7 changes: 5 additions & 2 deletions uvdat/core/rest/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from django.http import HttpResponse
from rest_framework.decorators import action
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import Project
from uvdat.core.models.simulations import AVAILABLE_SIMULATIONS, SimulationResult
from uvdat.core.rest.filter import AccessControl
import uvdat.core.rest.serializers as uvdat_serializers


Expand Down Expand Up @@ -60,8 +61,10 @@ def get_available_simulations(project_id: int):
return sims


class SimulationViewSet(GenericViewSet):
class SimulationViewSet(ModelViewSet):
queryset = SimulationResult.objects.all()
serializer_class = uvdat_serializers.SimulationResultSerializer
filter_backends = [AccessControl]

@action(
detail=False,
Expand Down

0 comments on commit 1754ac3

Please sign in to comment.