Skip to content

Commit

Permalink
feat: dataloaded list (#44)
Browse files Browse the repository at this point in the history
* feat: DjangoDataloadedListField
  • Loading branch information
superlevure committed Jan 30, 2024
1 parent 62a0ea9 commit 23f51ba
Show file tree
Hide file tree
Showing 6 changed files with 429 additions and 35 deletions.
34 changes: 34 additions & 0 deletions docs/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,40 @@ published and have a title:
schema = Schema(query=Query)
DjangoDataloadedListField
-------------------------

``DjangoDataloadedListField`` allows you to define a dataloaded list of :ref:`DjangoObjectType<queries-objecttypes>`'s from a related :ref:`DjangoObjectType<queries-objecttypes>`.
By default it will resolve the default related queryset of the Django model, but a custom resolver can be defined.

.. code:: python
from graphene import ObjectType, Schema
from graphene_django import DjangoDataloadedListField
class RecipeType(DjangoObjectType):
class Meta:
model = Recipe
fields = "__all__"
ingredients_dataloaded = DjangoDataloadedListField(IngredientType, field="ingredients")
ingredients_dataloaded_custom_resolver = DjangoDataloadedListField(IngredientType, field="ingredients")
def resolve_ingredients_dataloaded_custom_resolver(self, info):
# Important: the queryset returned by the resolver must derivate
# from the root model:
# return self.ingredient.filter(name = "Sugar") # bad
return Ingredient.objects.filter(name = "Sugar") # good
class IngredientType(DjangoObjectType):
class Meta:
model = Ingredient
fields = "__all__"
class Query(ObjectType):
recipes = DjangoListField(RecipeType)
schema = Schema(query=Query)
DjangoConnectionField
---------------------
Expand Down
3 changes: 2 additions & 1 deletion graphene_django/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .fields import DjangoConnectionField, DjangoListField
from .fields import DjangoConnectionField, DjangoDataloadedListField, DjangoListField
from .types import DjangoObjectType
from .utils import bypass_get_queryset

Expand All @@ -8,6 +8,7 @@
"__version__",
"DjangoObjectType",
"DjangoListField",
"DjangoDataloadedListField",
"DjangoConnectionField",
"bypass_get_queryset",
]
14 changes: 7 additions & 7 deletions graphene_django/debug/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..types import DjangoDebug


class context:
class Context:
pass


Expand Down Expand Up @@ -59,7 +59,7 @@ def resolve_reporter(self, info, **args):
}
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
query, context_value=Context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
Expand Down Expand Up @@ -126,7 +126,7 @@ def resolve_reporter(self, info, **args):
with django_assert_num_queries(3):
result = schema.execute(
query,
context_value=context(),
context_value=Context(),
middleware=[DjangoDebugMiddleware()],
execution_context_class=execution_context_class,
)
Expand Down Expand Up @@ -179,7 +179,7 @@ def resolve_all_reporters(self, info, **args):
}
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
query, context_value=Context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
Expand Down Expand Up @@ -230,7 +230,7 @@ def resolve_all_reporters(self, info, **args):
with django_assert_num_queries(1) as captured:
result = schema.execute(
query,
context_value=context(),
context_value=Context(),
middleware=[DjangoDebugMiddleware()],
execution_context_class=execution_context_class,
)
Expand Down Expand Up @@ -293,7 +293,7 @@ def resolve_all_reporters(self, info, **args):
with django_assert_num_queries(1) as captured:
result = schema.execute(
query,
context_value=context(),
context_value=Context(),
middleware=[DjangoDebugMiddleware()],
execution_context_class=execution_context_class,
)
Expand Down Expand Up @@ -337,7 +337,7 @@ def resolve_reporter(self, info, **args):
"""
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
query, context_value=Context(), middleware=[DjangoDebugMiddleware()]
)
assert result.errors
assert len(result.data["_debug"]["exceptions"])
Expand Down
112 changes: 110 additions & 2 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

import django
from django.db.models import IntegerField, Value
from django.db.models import F, IntegerField, Value
from django.db.models.query import QuerySet
from graphql_relay import (
cursor_to_offset,
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(self, _type, *args, **kwargs):
if isinstance(_type, NonNull):
_type = _type.of_type

# Django would never return a Set of None vvvvvvv
# Django would never return a Set of None
super().__init__(List(NonNull(_type)), *args, **kwargs)

@property
Expand Down Expand Up @@ -87,6 +87,114 @@ def wrap_resolve(self, parent_resolver):
)


class DjangoDataloadedListField(Field):
def __init__(
self,
_type,
field,
*args,
**kwargs,
):
from graphene_django.types import DjangoObjectType

if isinstance(_type, NonNull):
_type = _type.of_type

# Django would never return a Set of None
super().__init__(List(NonNull(_type)), *args, **kwargs)

assert issubclass(
self._underlying_type, DjangoObjectType
), "DjangoListField only accepts DjangoObjectType types"

self._field = field

@property
def _underlying_type(self):
_type = self._type
while hasattr(_type, "of_type"):
_type = _type.of_type
return _type

@property
def model(self):
return self._underlying_type._meta.model

def get_manager(self):
return self.model._default_manager

@staticmethod
def list_resolver(
field, django_object_type, resolver, default_manager, root, info, **args
):
related_name = root._meta.get_field(field).remote_field.name
many_to_many = root._meta.get_field(field).many_to_many

queryset = maybe_queryset(resolver(root, info, **args))
if queryset is None:
queryset = maybe_queryset(default_manager)

if isinstance(queryset, QuerySet):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))

if info.context is not None and graphene_settings.USE_DATALOADERS:
try:
if not hasattr(info.context, "dataloaders"):
info.context.dataloaders = {}
except AttributeError:
pass
else:
dataloader_key = get_info_cache_key(info)

if dataloader_key not in info.context.dataloaders:

def load_many(keys):
results_by_ids = defaultdict(list)
if many_to_many:
lookup = {
f"{related_name}__in": keys,
}
annotation = {f"{related_name}_id": F(related_name)}
qs = queryset.filter(**lookup).annotate(**annotation)
else:
lookup = {
f"{related_name}_id__in": keys,
}

qs = queryset.filter(**lookup)

for result in qs.iterator():
results_by_ids[
getattr(result, f"{related_name}_id")
].append(result)

return [results_by_ids.get(id, []) for id in keys]

info.context.dataloaders[dataloader_key] = SyncDataLoader(load_many)

return info.context.dataloaders[dataloader_key].load(root.id)

if many_to_many:
return queryset.filter(**{f"{related_name}": root.id})

return queryset.filter(**{f"{related_name}_id": root.id})

def wrap_resolve(self, parent_resolver):
resolver = super().wrap_resolve(parent_resolver)
_type = self.type
if isinstance(_type, NonNull):
_type = _type.of_type
django_object_type = _type.of_type.of_type
return partial(
self.list_resolver,
self._field,
django_object_type,
resolver,
self.get_manager(),
)


class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs):
self.on = kwargs.pop("on", False)
Expand Down
4 changes: 1 addition & 3 deletions graphene_django/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ class FilmDetails(models.Model):


class Film(models.Model):
class Meta:
ordering = ["pk"]

name = models.CharField(max_length=30)
genre = models.CharField(
max_length=2,
help_text="Genre",
Expand Down
Loading

0 comments on commit 23f51ba

Please sign in to comment.