diff --git a/metarecord/tests/test_api.py b/metarecord/tests/test_api.py index 2a9af752..8a0118ab 100644 --- a/metarecord/tests/test_api.py +++ b/metarecord/tests/test_api.py @@ -2198,7 +2198,7 @@ def test_function_classification_code_filtering( @pytest.mark.django_db def test_function_information_system_filtering( - api_client, user_api_client, classification, classification_2 + user_api_client, classification, classification_2 ): third_classification = Classification.objects.create( title="testification", @@ -2250,13 +2250,50 @@ def test_function_information_system_filtering( assert function.uuid.hex != function_2.uuid.hex assert function.uuid.hex != function_3.uuid.hex - response = api_client.get(FUNCTION_LIST_URL + "?information_system=xyz") + response = user_api_client.get(FUNCTION_LIST_URL + "?information_system=xyz") assert response.status_code == 200 results = response.data["results"] assert len(results) == 1 assert results[0]["id"] == function.uuid.hex +@pytest.mark.django_db +def test_function_information_system_filtering_for_unauthenticated_user( + api_client, classification, classification_2 +): + """Filtering by information system as an unauthenticated user should not have any effect""" + third_classification = Classification.objects.create( + title="testification", + code="00 100", + state=Classification.APPROVED, + function_allowed=True, + ) + function = Function.objects.create( + classification=classification, state=Function.APPROVED + ) + Function.objects.create(classification=classification_2, state=Function.APPROVED) + Function.objects.create( + classification=third_classification, state=Function.APPROVED + ) + + phase = Phase.objects.create( + attributes={"TypeSpecifier": "test phase"}, function=function, index=1 + ) + + action = Action.objects.create( + attributes={"TypeSpecifier": "test action"}, phase=phase, index=1 + ) + + Record.objects.create( + attributes={"InformationSystem": "xyz"}, action=action, index=1 + ) + + response = api_client.get(FUNCTION_LIST_URL + "?information_system=xyz") + assert response.status_code == 200 + results = response.data["results"] + assert len(results) == 3 + + @pytest.mark.django_db def test_function_detail_shows_record_information_system_for_authenticated_user( user_api_client, classification diff --git a/metarecord/views/function.py b/metarecord/views/function.py index db9f5d62..bc4d1775 100644 --- a/metarecord/views/function.py +++ b/metarecord/views/function.py @@ -418,6 +418,15 @@ class Meta: lookup_expr="icontains", ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Restrict querying information system queries to authenticated users. + # The information system field contents are not public. + user = getattr(self.request, "user", None) + if not user or not user.is_authenticated: + self.filters.pop("information_system", None) + def filter_valid_at(self, queryset, name, value): # if neither date is set the function is considered not valid queryset = queryset.exclude( diff --git a/search_indices/tests/conftest.py b/search_indices/tests/conftest.py index a1836a19..a9a36ea5 100644 --- a/search_indices/tests/conftest.py +++ b/search_indices/tests/conftest.py @@ -89,6 +89,13 @@ def action(phase): ) +@fixture +def action_2(phase_2): + return Action.objects.create( + attributes={"AdditionalInformation": "testisana"}, phase=phase_2, index=1 + ) + + @fixture def classification(): return Classification.objects.create( @@ -100,7 +107,7 @@ def classification(): @fixture -def classification2(): +def classification_2(): return Classification.objects.create( title="testisana ja toinen testisana", code="00 00", @@ -117,6 +124,14 @@ def function(classification): ) +@fixture +def function_2(classification_2): + return Function.objects.create( + attributes={"AdditionalInformation": "testword"}, + classification=classification_2, + ) + + @fixture def phase(function): return Phase.objects.create( @@ -124,8 +139,29 @@ def phase(function): ) +@fixture +def phase_2(function_2): + return Phase.objects.create( + attributes={"AdditionalInformation": "testword"}, function=function_2, index=1 + ) + + @fixture def record(action): return Record.objects.create( attributes={"AdditionalInformation": "testisana"}, action=action, index=1 ) + + +@fixture +def record_with_information_system(action): + return Record.objects.create( + attributes={"InformationSystem": "xyz"}, action=action, index=1 + ) + + +@fixture +def record_2(action_2): + return Record.objects.create( + attributes={"AdditionalInformation": "testword"}, action=action_2, index=1 + ) diff --git a/search_indices/tests/test_elastic_api.py b/search_indices/tests/test_elastic_api.py index 9c0c0624..93eeae65 100644 --- a/search_indices/tests/test_elastic_api.py +++ b/search_indices/tests/test_elastic_api.py @@ -1,5 +1,8 @@ import pytest from rest_framework.reverse import reverse +from rest_framework.test import APIClient + +from metarecord.models import Record ACTION_LIST_URL = reverse("action_search-list") ALL_LIST_URL = reverse("all_search-list") @@ -43,14 +46,14 @@ def test_classification_search_fuzzy2(user_api_client, classification): @pytest.mark.django_db -def test_classification_search_query_string(user_api_client, classification2): +def test_classification_search_query_string(user_api_client, classification_2): url = ALL_LIST_URL + '?search_simple_query_string="testisana ja toinen testisana"' response = user_api_client.get(url) assert response.status_code == 200 results = response.data["results"] if "results" in response.data else response.data uuids = list(result["id"] for result in results) - assert classification2.uuid.hex in uuids + assert classification_2.uuid.hex in uuids @pytest.mark.django_db @@ -98,11 +101,48 @@ def test_phase_filter_attribute_exact(user_api_client, phase): @pytest.mark.django_db -def test_record_filter_attribute_exact(user_api_client, record): +def test_record_filter_attribute_exact(user_api_client, record, record_2): + assert Record.objects.count() == 2 + url = RECORD_LIST_URL + "?record_AdditionalInformation=testisana" response = user_api_client.get(url) assert response.status_code == 200 results = response.data["results"] if "results" in response.data else response.data uuids = list(result["id"] for result in results) + assert len(results) == 1 assert record.uuid.hex in uuids + assert record_2.uuid.hex not in uuids + + +@pytest.mark.django_db +def test_record_filter_information_system_attribute_exact_filters_for_authenticated( + user_api_client, record_with_information_system, record_2 +): + assert Record.objects.count() == 2 + + url = RECORD_LIST_URL + "?record_InformationSystem=xyz" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + assert record_with_information_system.uuid.hex in uuids + assert record_2.uuid.hex not in uuids + + +@pytest.mark.django_db +def test_record_filter_information_system_attribute_exact_does_not_filter_for_unauthenticated( + record_with_information_system, record_2 +): + assert Record.objects.count() == 2 + + url = RECORD_LIST_URL + "?record_InformationSystem=xyz" + api_client = APIClient() + response = api_client.get(url) + + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + assert response.status_code == 200 + assert record_with_information_system.uuid.hex in uuids + assert record_2.uuid.hex in uuids diff --git a/search_indices/views/base.py b/search_indices/views/base.py index c69bd28e..aa1d5f52 100644 --- a/search_indices/views/base.py +++ b/search_indices/views/base.py @@ -67,3 +67,11 @@ class BaseSearchDocumentViewSet(BaseDocumentViewSet): "type", "_score", ) + + def filter_queryset(self, queryset): + # Restrict querying information system queries to authenticated users. + # The information system field contents are not public. + if not self.request.user.is_authenticated: + self.filter_fields.pop("record_InformationSystem", None) + + return super().filter_queryset(queryset)