diff --git a/search_indices/serializers/base.py b/search_indices/serializers/base.py index 5c98efa6..25e22f16 100644 --- a/search_indices/serializers/base.py +++ b/search_indices/serializers/base.py @@ -30,5 +30,10 @@ class Meta: def get_score(self, obj: Hit) -> int: return obj.meta.score + @property + def is_authenticated(self): + request = self.context.get("request") + return request and request.user.is_authenticated + def get_attributes(self, obj: Hit) -> Optional[dict]: - return get_attributes(obj, "attributes") + return get_attributes(obj, "attributes", self.is_authenticated) diff --git a/search_indices/serializers/utils.py b/search_indices/serializers/utils.py index eb17d972..4efd9e26 100644 --- a/search_indices/serializers/utils.py +++ b/search_indices/serializers/utils.py @@ -2,8 +2,18 @@ from elasticsearch_dsl.response.hit import Hit +attributes_for_authenticated = ( + "function_InformationSystem", + "action_InformationSystem", + "classification_InformationSystem", + "record_InformationSystem", + "phase_InformationSystem", +) -def get_attributes(obj: Hit, attribute_field_name: str) -> Optional[dict]: + +def get_attributes( + obj: Hit, attribute_field_name: str, authenticated: bool +) -> Optional[dict]: """ Fetch attributes from index and revert the attribute names that have "." replaced with "+". @@ -14,6 +24,11 @@ def get_attributes(obj: Hit, attribute_field_name: str) -> Optional[dict]: attrs = attrs.to_dict() for key, value in attrs.items(): key = key.replace("+", ".") + + if not authenticated and key in attributes_for_authenticated: + continue + attributes[key] = value + return attributes return None diff --git a/search_indices/tests/conftest.py b/search_indices/tests/conftest.py index a9a36ea5..db6d8a62 100644 --- a/search_indices/tests/conftest.py +++ b/search_indices/tests/conftest.py @@ -7,6 +7,7 @@ from elasticsearch import Elasticsearch from elasticsearch_dsl.connections import add_connection from pytest import fixture +from rest_framework.test import APIClient from metarecord.models import Action, Classification, Function, Phase, Record from metarecord.tests.conftest import user, user_api_client # noqa @@ -57,7 +58,7 @@ def destroy_indices(): RecordDocument._index.delete(ignore=[400, 404]) -@fixture(scope="session", autouse=True) +@fixture(autouse=True) def create_indices(): """ Initialize all indices with the custom analyzers. @@ -82,6 +83,11 @@ def es_connection(): yield es_connection +@fixture +def api_client(): + return APIClient() + + @fixture def action(phase): return Action.objects.create( @@ -89,6 +95,13 @@ def action(phase): ) +@fixture +def action_with_information_system(phase): + return Action.objects.create( + attributes={"InformationSystem": "xyz"}, phase=phase, index=1 + ) + + @fixture def action_2(phase_2): return Action.objects.create( @@ -124,6 +137,13 @@ def function(classification): ) +@fixture +def function_with_information_system(classification_2): + return Function.objects.create( + attributes={"InformationSystem": "xyz"}, classification=classification_2 + ) + + @fixture def function_2(classification_2): return Function.objects.create( @@ -139,6 +159,13 @@ def phase(function): ) +@fixture +def phase_with_information_system(function): + return Phase.objects.create( + attributes={"InformationSystem": "xyz"}, function=function, index=1 + ) + + @fixture def phase_2(function_2): return Phase.objects.create( diff --git a/search_indices/tests/test_elastic_api.py b/search_indices/tests/test_elastic_api.py index 93eeae65..43789071 100644 --- a/search_indices/tests/test_elastic_api.py +++ b/search_indices/tests/test_elastic_api.py @@ -1,8 +1,7 @@ import pytest from rest_framework.reverse import reverse -from rest_framework.test import APIClient -from metarecord.models import Record +from metarecord.models import Phase, Record ACTION_LIST_URL = reverse("action_search-list") ALL_LIST_URL = reverse("all_search-list") @@ -45,6 +44,207 @@ def test_classification_search_fuzzy2(user_api_client, classification): assert classification.uuid.hex in uuids +@pytest.mark.django_db +def test_function_information_system_all_search_unauthenticated_no_effect( + api_client, function_with_information_system, phase +): + phase.attributes = {"AdditionalInformation": "xyz"} + phase.save() + + response = api_client.get(ALL_LIST_URL + "?search=xyz") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + + assert function_with_information_system.uuid.hex not in uuids + assert phase.uuid.hex in uuids + + +@pytest.mark.django_db +def test_function_information_system_all_search_unauthenticated_not_visible( + api_client, function_with_information_system, phase +): + phase.attributes = {"AdditionalInformation": "testing"} + phase.save() + function_with_information_system.attributes = { + "InformationSystem": "xyz", + "AdditionalInformation": "testing", + } + function_with_information_system.save() + + response = api_client.get(ALL_LIST_URL + "?search=testing") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert function_with_information_system.uuid.hex in uuids + assert phase.uuid.hex in uuids + assert {"function_InformationSystem": "xyz"} not in attributes + + +@pytest.mark.django_db +def test_function_information_system_all_search_authenticated( + user_api_client, function_with_information_system, phase +): + response = user_api_client.get(ALL_LIST_URL + "?search=xyz") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert function_with_information_system.uuid.hex in uuids + assert phase.uuid.hex not in uuids + assert {"function_InformationSystem": "xyz"} in attributes + + +@pytest.mark.django_db +def test_action_information_system_all_search_unauthenticated_no_effect( + api_client, action_with_information_system, function_2 +): + function_2.attributes = {"AdditionalInformation": "xyz"} + function_2.save() + + response = api_client.get(ALL_LIST_URL + "?search=xyz") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + + assert action_with_information_system.uuid.hex not in uuids + assert function_2.uuid.hex in uuids + + +@pytest.mark.django_db +def test_action_information_system_all_search_unauthenticated_not_visible( + api_client, action_with_information_system, function_2 +): + function_2.attributes = {"AdditionalInformation": "testing"} + function_2.save() + action_with_information_system.attributes = { + "InformationSystem": "xyz", + "AdditionalInformation": "testing", + } + action_with_information_system.save() + + response = api_client.get(ALL_LIST_URL + "?search=testing") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert action_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} not in attributes + + +@pytest.mark.django_db +def test_action_information_system_all_search_authenticated( + user_api_client, action_with_information_system, function_2 +): + response = user_api_client.get(ALL_LIST_URL + "?search=xyz") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + assert action_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex not in uuids + assert {"action_InformationSystem": "xyz"} in attributes + + +@pytest.mark.django_db +def test_record_information_system_all_search_unauthenticated_no_effect( + api_client, record_with_information_system, function_2 +): + function_2.attributes = {"AdditionalInformation": "xyz"} + function_2.save() + + response = api_client.get(ALL_LIST_URL + "?search=xyz") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + + assert record_with_information_system.uuid.hex not in uuids + assert function_2.uuid.hex in uuids + + +@pytest.mark.django_db +def test_record_information_system_all_search_unauthenticated_not_visible( + api_client, record_with_information_system, function_2 +): + function_2.attributes = {"AdditionalInformation": "testing"} + function_2.save() + record_with_information_system.attributes = { + "InformationSystem": "xyz", + "AdditionalInformation": "testing", + } + record_with_information_system.save() + + response = api_client.get(ALL_LIST_URL + "?search=testing") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert record_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} not in attributes + + +@pytest.mark.django_db +def test_record_information_system_all_search_authenticated( + user_api_client, record_with_information_system, function_2 +): + response = user_api_client.get(ALL_LIST_URL + "?search=xyz") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + assert record_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex not in uuids + assert {"record_InformationSystem": "xyz"} in attributes + + +@pytest.mark.django_db +def test_phase_information_system_all_search_unauthenticated_no_effect( + api_client, phase_with_information_system, function_2 +): + function_2.attributes = {"AdditionalInformation": "xyz"} + function_2.save() + + response = api_client.get(ALL_LIST_URL + "?search=xyz") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + + assert phase_with_information_system.uuid.hex not in uuids + assert function_2.uuid.hex in uuids + + +@pytest.mark.django_db +def test_phase_information_system_all_search_unauthenticated_not_visible( + api_client, phase_with_information_system, function_2 +): + function_2.attributes = {"AdditionalInformation": "testing"} + function_2.save() + phase_with_information_system.attributes = { + "InformationSystem": "xyz", + "AdditionalInformation": "testing", + } + phase_with_information_system.save() + + response = api_client.get(ALL_LIST_URL + "?search=testing") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert phase_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} not in attributes + + +@pytest.mark.django_db +def test_phase_information_system_all_search_authenticated( + user_api_client, phase_with_information_system, function_2 +): + response = user_api_client.get(ALL_LIST_URL + "?search=xyz") + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + assert phase_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex not in uuids + assert {"phase_InformationSystem": "xyz"} in attributes + + @pytest.mark.django_db def test_classification_search_query_string(user_api_client, classification_2): url = ALL_LIST_URL + '?search_simple_query_string="testisana ja toinen testisana"' @@ -67,6 +267,35 @@ def test_action_filter_attribute_exact(user_api_client, action): assert action.uuid.hex in uuids +@pytest.mark.django_db +def test_action_information_system_attribute_for_authenticated( + user_api_client, action_with_information_system +): + response = user_api_client.get(ACTION_LIST_URL) + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert response.status_code == 200 + assert action_with_information_system.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} in attributes + + +@pytest.mark.django_db +def test_action_does_not_show_information_system_attribute_for_unauthenticated( + api_client, action_with_information_system +): + response = api_client.get(ACTION_LIST_URL) + + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert response.status_code == 200 + assert action_with_information_system.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} not in attributes + + @pytest.mark.django_db def test_classification_filter_title_exact(user_api_client, classification): url = CLASSIFICATION_LIST_URL + "?title=testisana" @@ -131,14 +360,44 @@ def test_record_filter_information_system_attribute_exact_filters_for_authentica assert record_2.uuid.hex not in uuids +@pytest.mark.django_db +def test_record_does_shows_information_system_attribute_for_authenticated( + user_api_client, record_with_information_system, record_2 +): + response = user_api_client.get(RECORD_LIST_URL) + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert response.status_code == 200 + assert record_with_information_system.uuid.hex in uuids + assert {"record_InformationSystem": "xyz"} in attributes + assert record_2.uuid.hex in uuids + + +@pytest.mark.django_db +def test_record_does_not_show_information_system_attribute_for_unauthenticated( + api_client, record_with_information_system, record_2 +): + response = api_client.get(RECORD_LIST_URL) + + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert response.status_code == 200 + assert record_with_information_system.uuid.hex in uuids + assert {"record_InformationSystem": "xyz"} not in attributes + assert record_2.uuid.hex in uuids + + @pytest.mark.django_db def test_record_filter_information_system_attribute_exact_does_not_filter_for_unauthenticated( - record_with_information_system, record_2 + api_client, record_with_information_system, record_2 ): assert Record.objects.count() == 2 url = RECORD_LIST_URL + "?record_InformationSystem=xyz" - api_client = APIClient() response = api_client.get(url) results = response.data["results"] if "results" in response.data else response.data @@ -146,3 +405,34 @@ def test_record_filter_information_system_attribute_exact_does_not_filter_for_un assert response.status_code == 200 assert record_with_information_system.uuid.hex in uuids assert record_2.uuid.hex in uuids + + +@pytest.mark.django_db +def test_phase_filter_information_system_attribute_exact_for_authenticated( + user_api_client, phase_with_information_system, phase_2 +): + assert Phase.objects.count() == 2 + url = PHASE_LIST_URL + "?phase_InformationSystem=xyz" + response = user_api_client.get(url) + + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + assert response.status_code == 200 + assert phase_with_information_system.uuid.hex in uuids + assert phase_2.uuid.hex not in uuids + + +@pytest.mark.django_db +def test_phase_filter_information_system_attribute_exact_does_not_filter_for_unauthenticated( + api_client, phase_with_information_system, phase_2 +): + assert Phase.objects.count() == 2 + + url = PHASE_LIST_URL + "?phase_InformationSystem=xyz" + response = api_client.get(url) + + results = response.data["results"] if "results" in response.data else response.data + uuids = list(result["id"] for result in results) + assert response.status_code == 200 + assert phase_with_information_system.uuid.hex in uuids + assert phase_2.uuid.hex in uuids diff --git a/search_indices/views/base.py b/search_indices/views/base.py index aa1d5f52..bfe69761 100644 --- a/search_indices/views/base.py +++ b/search_indices/views/base.py @@ -8,6 +8,7 @@ from metarecord.pagination import ESRecordPagination from search_indices.backends.faceted_attribute_backend import FacetedAttributeBackend +from search_indices.serializers.utils import attributes_for_authenticated from search_indices.views.utils import populate_filter_fields_with_attributes @@ -68,10 +69,23 @@ class BaseSearchDocumentViewSet(BaseDocumentViewSet): "_score", ) - def filter_queryset(self, queryset): + def _filter_search_fields_for_unauthenticated(self): + search_fields = [] + for field in self.search_fields: + if "InformationSystem" in field: + continue + search_fields.append(field) + self.search_fields = tuple(search_fields) + + for attribute in attributes_for_authenticated: + self.filter_fields.pop(attribute, None) + + def initial(self, request, *args, **kwargs): + if request.user.is_authenticated: + return super().initial(request, *args, **kwargs) + # Restrict querying information system queries to authenticated users. # The information system field contents are not public. - if not self.request.user.is_authenticated: - self.filter_fields.pop("record_InformationSystem", None) + self._filter_search_fields_for_unauthenticated() - return super().filter_queryset(queryset) + super().initial(request, *args, **kwargs)