Skip to content

Commit

Permalink
Handle the default django list field and test the async execution of …
Browse files Browse the repository at this point in the history
…the fields
  • Loading branch information
jaw9c committed May 5, 2023
1 parent c10753d commit e9d5e88
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 13 deletions.
52 changes: 39 additions & 13 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,41 @@ def get_manager(self):
def list_resolver(
django_object_type, resolver, default_manager, root, info, **args
):
queryset = maybe_queryset(resolver(root, info, **args))
iterable = resolver(root, info, **args)

if info.is_awaitable(iterable):

async def resolve_list_async(iterable):
queryset = maybe_queryset(await iterable)
if queryset is None:
queryset = maybe_queryset(default_manager)

if isinstance(queryset, QuerySet):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(
await sync_to_async(django_object_type.get_queryset)(
queryset, info
)
)

return await sync_to_async(list)(queryset)

return resolve_list_async(iterable)

queryset = maybe_queryset(iterable)
if queryset is None:
queryset = maybe_queryset(default_manager)

if isinstance(queryset, QuerySet):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))

try:
try:
get_running_loop()
except RuntimeError:
pass
pass
else:
return queryset.aiterator()
return sync_to_async(list)(queryset)

return queryset

Expand Down Expand Up @@ -238,34 +259,39 @@ def connection_resolver(
# or a resolve_foo (does not accept queryset)

iterable = resolver(root, info, **args)

if info.is_awaitable(iterable):

async def resolve_connection_async(iterable):
iterable = await iterable
if iterable is None:
iterable = default_manager
## This could also be async
iterable = queryset_resolver(connection, iterable, info, args)

if info.is_awaitable(iterable):
iterable = await iterable

return await sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)

return await sync_to_async(cls.resolve_connection)(
connection, args, iterable, max_limit=max_limit
)

return resolve_connection_async(iterable)

if iterable is None:
iterable = default_manager
# thus the iterable gets refiltered by resolve_queryset
# but iterable might be promise
iterable = queryset_resolver(connection, iterable, info, args)

try:
try:
get_running_loop()
except RuntimeError:
pass
pass
else:
return sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)

return sync_to_async(cls.resolve_connection)(
connection, args, iterable, max_limit=max_limit
)

return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)

Expand Down
6 changes: 6 additions & 0 deletions graphene_django/tests/async_test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from asgiref.sync import async_to_sync


def assert_async_result_equal(schema, query, result):
async_result = async_to_sync(schema.execute_async)(query)
assert async_result == result
139 changes: 139 additions & 0 deletions graphene_django/tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import re
from django.db.models import Count, Prefetch
from asgiref.sync import sync_to_async, async_to_sync

import pytest

Expand All @@ -14,6 +15,7 @@
FilmDetails as FilmDetailsModel,
Reporter as ReporterModel,
)
from .async_test_helper import assert_async_result_equal


class TestDjangoListField:
Expand Down Expand Up @@ -75,6 +77,7 @@ class Query(ObjectType):

result = schema.execute(query)

assert_async_result_equal(schema, query, result)
assert not result.errors
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
Expand Down Expand Up @@ -102,6 +105,7 @@ class Query(ObjectType):
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": []}
assert_async_result_equal(schema, query, result)

ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
Expand All @@ -112,6 +116,7 @@ class Query(ObjectType):
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
}
assert_async_result_equal(schema, query, result)

def test_override_resolver(self):
class Reporter(DjangoObjectType):
Expand Down Expand Up @@ -139,6 +144,37 @@ def resolve_reporters(_, info):
ReporterModel.objects.create(first_name="Debra", last_name="Payne")

result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}

def test_override_resolver_async_execution(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)

class Query(ObjectType):
reporters = DjangoListField(Reporter)

@staticmethod
@sync_to_async
def resolve_reporters(_, info):
return ReporterModel.objects.filter(first_name="Tara")

schema = Schema(query=Query)

query = """
query {
reporters {
firstName
}
}
"""

ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")

result = async_to_sync(schema.execute_async)(query)

assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
Expand Down Expand Up @@ -203,6 +239,7 @@ class Query(ObjectType):
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)

def test_override_resolver_nested_list_field(self):
class Article(DjangoObjectType):
Expand Down Expand Up @@ -261,6 +298,7 @@ class Query(ObjectType):
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)

def test_get_queryset_filter(self):
class Reporter(DjangoObjectType):
Expand Down Expand Up @@ -306,6 +344,7 @@ def resolve_reporters(_, info):

assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
assert_async_result_equal(schema, query, result)

def test_resolve_list(self):
"""Resolving a plain list should work (and not call get_queryset)"""
Expand Down Expand Up @@ -354,6 +393,55 @@ def resolve_reporters(_, info):
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}

def test_resolve_list_async(self):
"""Resolving a plain list should work (and not call get_queryset) when running under async"""

class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")

@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)

class Query(ObjectType):
reporters = DjangoListField(Reporter)

@staticmethod
@sync_to_async
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]

schema = Schema(query=Query)

query = """
query {
reporters {
firstName
}
}
"""

r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")

ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)

result = async_to_sync(schema.execute_async)(query)

assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}

def test_get_queryset_foreign_key(self):
class Article(DjangoObjectType):
class Meta:
Expand Down Expand Up @@ -413,6 +501,7 @@ class Query(ObjectType):
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)

def test_resolve_list_external_resolver(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
Expand Down Expand Up @@ -461,6 +550,54 @@ class Query(ObjectType):
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}

def test_resolve_list_external_resolver_async(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""

class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")

@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)

@sync_to_async
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]

class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)

schema = Schema(query=Query)

query = """
query {
reporters {
firstName
}
}
"""

r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")

ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)

result = async_to_sync(schema.execute_async)(query)

assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}

def test_get_queryset_filter_external_resolver(self):
class Reporter(DjangoObjectType):
class Meta:
Expand Down Expand Up @@ -505,6 +642,7 @@ class Query(ObjectType):

assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
assert_async_result_equal(schema, query, result)

def test_select_related_and_prefetch_related_are_respected(
self, django_assert_num_queries
Expand Down Expand Up @@ -647,3 +785,4 @@ def resolve_articles(root, info):
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
captured.captured_queries[1]["sql"],
)
assert_async_result_equal(schema, query, result)

0 comments on commit e9d5e88

Please sign in to comment.