diff --git a/graphene_django/compat.py b/graphene_django/compat.py index 8a2b93369..537fd1da3 100644 --- a/graphene_django/compat.py +++ b/graphene_django/compat.py @@ -6,13 +6,16 @@ class MissingType(object): # Postgres fields are only available in Django with psycopg2 installed # and we cannot have psycopg2 on PyPy from django.contrib.postgres.fields import ( + IntegerRangeField, ArrayField, HStoreField, JSONField as PGJSONField, RangeField, ) except ImportError: - ArrayField, HStoreField, PGJSONField, RangeField = (MissingType,) * 4 + IntegerRangeField, ArrayField, HStoreField, PGJSONField, RangeField = ( + MissingType, + ) * 5 try: # JSONField is only available from Django 3.1 diff --git a/graphene_django/filter/tests/conftest.py b/graphene_django/filter/tests/conftest.py new file mode 100644 index 000000000..031364519 --- /dev/null +++ b/graphene_django/filter/tests/conftest.py @@ -0,0 +1,128 @@ +from mock import MagicMock +import pytest + +from django.db import models +from django.db.models.query import QuerySet +from django_filters import filters +from django_filters import FilterSet +import graphene +from graphene.relay import Node +from graphene_django import DjangoObjectType +from graphene_django.utils import DJANGO_FILTER_INSTALLED + +from ...compat import ArrayField + +pytestmark = [] + +if DJANGO_FILTER_INSTALLED: + from graphene_django.filter import DjangoFilterConnectionField +else: + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) + + +STORE = {"events": []} + + +@pytest.fixture +def Event(): + class Event(models.Model): + name = models.CharField(max_length=50) + tags = ArrayField(models.CharField(max_length=50)) + + return Event + + +@pytest.fixture +def EventFilterSet(Event): + + from django.contrib.postgres.forms import SimpleArrayField + + class ArrayFilter(filters.Filter): + base_field_class = SimpleArrayField + + class EventFilterSet(FilterSet): + class Meta: + model = Event + fields = { + "name": ["exact"], + } + + tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") + tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") + + return EventFilterSet + + +@pytest.fixture +def EventType(Event, EventFilterSet): + class EventType(DjangoObjectType): + class Meta: + model = Event + interfaces = (Node,) + filterset_class = EventFilterSet + + return EventType + + +@pytest.fixture +def Query(Event, EventType): + class Query(graphene.ObjectType): + events = DjangoFilterConnectionField(EventType) + + def resolve_events(self, info, **kwargs): + + events = [ + Event(name="Live Show", tags=["concert", "music", "rock"],), + Event(name="Musical", tags=["movie", "music"],), + Event(name="Ballet", tags=["concert", "dance"],), + ] + + STORE["events"] = events + + m_queryset = MagicMock(spec=QuerySet) + m_queryset.model = Event + + def filter_events(**kwargs): + if "tags__contains" in kwargs: + STORE["events"] = list( + filter( + lambda e: set(kwargs["tags__contains"]).issubset( + set(e.tags) + ), + STORE["events"], + ) + ) + if "tags__overlap" in kwargs: + STORE["events"] = list( + filter( + lambda e: not set(kwargs["tags__overlap"]).isdisjoint( + set(e.tags) + ), + STORE["events"], + ) + ) + + def mock_queryset_filter(*args, **kwargs): + filter_events(**kwargs) + return m_queryset + + def mock_queryset_none(*args, **kwargs): + STORE["events"] = [] + return m_queryset + + def mock_queryset_count(*args, **kwargs): + return len(STORE["events"]) + + m_queryset.all.return_value = m_queryset + m_queryset.filter.side_effect = mock_queryset_filter + m_queryset.none.side_effect = mock_queryset_none + m_queryset.count.side_effect = mock_queryset_count + m_queryset.__getitem__.side_effect = STORE["events"].__getitem__ + + return m_queryset + + return Query diff --git a/graphene_django/filter/tests/test_contains_filter.py b/graphene_django/filter/tests/test_contains_filter.py new file mode 100644 index 000000000..35e775ef5 --- /dev/null +++ b/graphene_django/filter/tests/test_contains_filter.py @@ -0,0 +1,82 @@ +import pytest + +from graphene import Schema + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_multiple(Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: ["concert", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_one(Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: ["music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_none(Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] diff --git a/graphene_django/filter/tests/test_overlap_filter.py b/graphene_django/filter/tests/test_overlap_filter.py new file mode 100644 index 000000000..32dfa44a1 --- /dev/null +++ b/graphene_django/filter/tests/test_overlap_filter.py @@ -0,0 +1,84 @@ +import pytest + +from graphene import Schema + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_multiple(Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: ["concert", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + {"node": {"name": "Ballet"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_one(Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: ["music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_none(Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index dce08c76b..2be3778ba 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -1,6 +1,6 @@ import six -from graphene import List +import graphene from django_filters.utils import get_model_field from django_filters.filters import Filter, BaseCSVFilter @@ -41,11 +41,11 @@ def get_filtering_args_from_filterset(filterset_class, type): field = convert_form_field(form_field) - if filter_type in ["in", "range"]: - # Replace CSV filters (`in`, `range`) argument type to be a list of + if filter_type in {"in", "range", "contains", "overlap"}: + # Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of # the same type as the field. See comments in # `replace_csv_filters` method for more details. - field = List(field.get_type()) + field = graphene.List(field.get_type()) field_type = field.Argument() field_type.description = filter_field.label @@ -71,7 +71,7 @@ def get_filterset_class(filterset_class, **meta): def replace_csv_filters(filterset_class): """ - Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore + Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore but regular Filter objects that simply use the input value as filter argument on the queryset. This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we @@ -81,8 +81,7 @@ def replace_csv_filters(filterset_class): """ for name, filter_field in six.iteritems(filterset_class.base_filters): filter_type = filter_field.lookup_expr - if filter_type == "in": - assert isinstance(filter_field, BaseCSVFilter) + if filter_type in {"in", "contains", "overlap"}: filterset_class.base_filters[name] = InFilter( field_name=filter_field.field_name, lookup_expr=filter_field.lookup_expr, @@ -92,8 +91,7 @@ def replace_csv_filters(filterset_class): **filter_field.extra ) - if filter_type == "range": - assert isinstance(filter_field, BaseCSVFilter) + elif filter_type == "range": filterset_class.base_filters[name] = RangeFilter( field_name=filter_field.field_name, lookup_expr=filter_field.lookup_expr, diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index a2d8373f0..5ff44664a 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -11,7 +11,7 @@ import graphene from graphene.relay import Node -from ..compat import JSONField, MissingType +from ..compat import IntegerRangeField, MissingType from ..fields import DjangoConnectionField from ..types import DjangoObjectType from ..utils import DJANGO_FILTER_INSTALLED @@ -113,7 +113,7 @@ def resolve_reporter(self, info): assert result.data == expected -@pytest.mark.skipif(JSONField is MissingType, reason="RangeField should exist") +@pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist") def test_should_query_postgres_fields(): from django.contrib.postgres.fields import ( IntegerRangeField,