Skip to content

Commit

Permalink
fix: address some server-side API bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
annehaley committed Sep 25, 2024
1 parent de74292 commit 6943cb4
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 12 deletions.
1 change: 1 addition & 0 deletions uvdat/core/rest/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class ChartViewSet(ModelViewSet):
queryset = Chart.objects.all()
serializer_class = ChartSerializer
filter_backends = [AccessControl]
lookup_field = "id"

def validate_editable(self, chart, func, *args, **kwargs):
if chart.editable:
Expand Down
1 change: 1 addition & 0 deletions uvdat/core/rest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class DatasetViewSet(ModelViewSet):
queryset = Dataset.objects.all()
serializer_class = DatasetSerializer
filter_backends = [AccessControl]
lookup_field = "id"

@action(detail=True, methods=['get'])
def map_layers(self, request, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions uvdat/core/rest/file_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ class FileItemViewSet(ModelViewSet):
queryset = FileItem.objects.all()
serializer_class = FileItemSerializer
filter_backends = [AccessControl]
lookup_field = "id"
15 changes: 6 additions & 9 deletions uvdat/core/rest/filter.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
from rest_framework.filters import BaseFilterBackend

from uvdat.core.models import Project

class AccessControl(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
project_id = request.query_params.get('project')
user = request.user
valid_ids = [o.id for o in queryset]
if project_id:
project = Project.objects.get(id=project_id)
queryset = [o for o in queryset if o.is_in_project(project)]
valid_ids = [o.id for o in queryset if o.is_in_project(project_id)]

if request.method == 'GET':
queryset = [o for o in queryset if o.readable_by(user)]
valid_ids = [o.id for o in queryset if o.readable_by(user) and o.id in valid_ids]
elif request.method == 'PUT' or request.method == 'PATCH':
queryset = [o for o in queryset if o.editable_by(user)]
valid_ids = [o.id for o in queryset if o.editable_by(user) and o.id in valid_ids]
elif request.method == 'DELETE':
queryset = [o for o in queryset if o.deletable_by(user)]
valid_ids = [o.id for o in queryset if o.deletable_by(user) and o.id in valid_ids]
elif request.method == 'POST':
# no access control required for POST requests
pass

return queryset
return queryset.filter(id__in=valid_ids)
2 changes: 2 additions & 0 deletions uvdat/core/rest/map_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class RasterMapLayerViewSet(ModelViewSet, LargeImageFileDetailMixin):
queryset = RasterMapLayer.objects.select_related('dataset').all()
serializer_class = RasterMapLayerSerializer
filter_backends = [AccessControl]
lookup_field = "id"
FILE_FIELD_NAME = 'cloud_optimized_geotiff'

@action(
Expand All @@ -92,6 +93,7 @@ class VectorMapLayerViewSet(ModelViewSet):
queryset = VectorMapLayer.objects.select_related('dataset').all()
serializer_class = VectorMapLayerSerializer
filter_backends = [AccessControl]
lookup_field = "id"

def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
Expand Down
3 changes: 3 additions & 0 deletions uvdat/core/rest/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@ class NetworkViewSet(ModelViewSet):
queryset = Network.objects.all()
serializer_class = NetworkSerializer
filter_backends = [AccessControl]
lookup_field = "id"


class NetworkNodeViewSet(ModelViewSet):
queryset = NetworkNode.objects.all()
serializer_class = NetworkNodeSerializer
filter_backends = [AccessControl]
lookup_field = "id"


class NetworkEdgeViewSet(ModelViewSet):
queryset = NetworkEdge.objects.all()
serializer_class = NetworkEdgeSerializer
filter_backends = [AccessControl]
lookup_field = "id"
4 changes: 4 additions & 0 deletions uvdat/core/rest/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class ProjectViewSet(ModelViewSet):
queryset = Project.objects.all()
serializer_class = ProjectSerializer
filter_backends = [AccessControl]
lookup_field = "id"

def perform_create(self, serializer):
serializer.save(owner=self.request.user)

@action(detail=True, methods=['get'])
def regions(self, request, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions uvdat/core/rest/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ class SourceRegionViewSet(ModelViewSet):
queryset = SourceRegion.objects.all()
serializer_class = SourceRegionSerializer
filter_backends = [AccessControl]
lookup_field = "id"


class DerivedRegionViewSet(ModelViewSet):
queryset = DerivedRegion.objects.all()
serializer_class = DerivedRegionListSerializer
filter_backends = [AccessControl]
lookup_field = "id"

def get_serializer_class(self):
if self.detail:
Expand Down
14 changes: 11 additions & 3 deletions uvdat/core/rest/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

from django.contrib.auth.models import User
from django.contrib.gis.geos import Point
from django.contrib.gis.serializers import geojson
from rest_framework import serializers

Expand Down Expand Up @@ -28,15 +29,22 @@ class Meta:

class ProjectSerializer(serializers.ModelSerializer):
default_map_center = serializers.SerializerMethodField('get_center')
owner = UserSerializer(allow_null=True)
collaborators = UserSerializer(many=True)
followers = UserSerializer(many=True)
owner = UserSerializer(allow_null=True, required=False)
collaborators = UserSerializer(many=True, required=False)
followers = UserSerializer(many=True, required=False)

def get_center(self, obj):
# Web client expects Lon, Lat
if obj.default_map_center:
return [obj.default_map_center.y, obj.default_map_center.x]

def to_internal_value(self, data):
center = data.get('default_map_center')
data = super().to_internal_value(data)
if isinstance(center, list):
data['default_map_center'] = Point(center[1], center[0])
return data

class Meta:
model = Project
fields = '__all__'
Expand Down
1 change: 1 addition & 0 deletions uvdat/core/rest/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class SimulationViewSet(ModelViewSet):
queryset = SimulationResult.objects.all()
serializer_class = uvdat_serializers.SimulationResultSerializer
filter_backends = [AccessControl]
lookup_field = "id"

@action(
detail=False,
Expand Down

0 comments on commit 6943cb4

Please sign in to comment.