From e9d5e88ea25b68c57c4a07a576010ea60f2dbfbe Mon Sep 17 00:00:00 2001 From: Josh Warwick Date: Fri, 5 May 2023 11:18:21 +0100 Subject: [PATCH] Handle the default django list field and test the async execution of the fields --- graphene_django/fields.py | 52 ++++++-- graphene_django/tests/async_test_helper.py | 6 + graphene_django/tests/test_fields.py | 139 +++++++++++++++++++++ 3 files changed, 184 insertions(+), 13 deletions(-) create mode 100644 graphene_django/tests/async_test_helper.py diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 6e1b0b1de..99e84c76b 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -53,7 +53,28 @@ def get_manager(self): def list_resolver( django_object_type, resolver, default_manager, root, info, **args ): - queryset = maybe_queryset(resolver(root, info, **args)) + iterable = resolver(root, info, **args) + + if info.is_awaitable(iterable): + + async def resolve_list_async(iterable): + queryset = maybe_queryset(await iterable) + if queryset is None: + queryset = maybe_queryset(default_manager) + + if isinstance(queryset, QuerySet): + # Pass queryset to the DjangoObjectType get_queryset method + queryset = maybe_queryset( + await sync_to_async(django_object_type.get_queryset)( + queryset, info + ) + ) + + return await sync_to_async(list)(queryset) + + return resolve_list_async(iterable) + + queryset = maybe_queryset(iterable) if queryset is None: queryset = maybe_queryset(default_manager) @@ -61,12 +82,12 @@ def list_resolver( # Pass queryset to the DjangoObjectType get_queryset method queryset = maybe_queryset(django_object_type.get_queryset(queryset, info)) - try: + try: get_running_loop() except RuntimeError: - pass + pass else: - return queryset.aiterator() + return sync_to_async(list)(queryset) return queryset @@ -238,34 +259,39 @@ def connection_resolver( # or a resolve_foo (does not accept queryset) iterable = resolver(root, info, **args) - + if info.is_awaitable(iterable): + async def resolve_connection_async(iterable): iterable = await iterable if iterable is None: iterable = default_manager ## This could also be async iterable = queryset_resolver(connection, iterable, info, args) - + if info.is_awaitable(iterable): iterable = await iterable - - return await sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit) + + return await sync_to_async(cls.resolve_connection)( + connection, args, iterable, max_limit=max_limit + ) + return resolve_connection_async(iterable) - + if iterable is None: iterable = default_manager # thus the iterable gets refiltered by resolve_queryset # but iterable might be promise iterable = queryset_resolver(connection, iterable, info, args) - try: + try: get_running_loop() except RuntimeError: - pass + pass else: - return sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit) - + return sync_to_async(cls.resolve_connection)( + connection, args, iterable, max_limit=max_limit + ) return cls.resolve_connection(connection, args, iterable, max_limit=max_limit) diff --git a/graphene_django/tests/async_test_helper.py b/graphene_django/tests/async_test_helper.py new file mode 100644 index 000000000..0487f8918 --- /dev/null +++ b/graphene_django/tests/async_test_helper.py @@ -0,0 +1,6 @@ +from asgiref.sync import async_to_sync + + +def assert_async_result_equal(schema, query, result): + async_result = async_to_sync(schema.execute_async)(query) + assert async_result == result diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py index 8c7b78d36..2a5055398 100644 --- a/graphene_django/tests/test_fields.py +++ b/graphene_django/tests/test_fields.py @@ -1,6 +1,7 @@ import datetime import re from django.db.models import Count, Prefetch +from asgiref.sync import sync_to_async, async_to_sync import pytest @@ -14,6 +15,7 @@ FilmDetails as FilmDetailsModel, Reporter as ReporterModel, ) +from .async_test_helper import assert_async_result_equal class TestDjangoListField: @@ -75,6 +77,7 @@ class Query(ObjectType): result = schema.execute(query) + assert_async_result_equal(schema, query, result) assert not result.errors assert result.data == { "reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}] @@ -102,6 +105,7 @@ class Query(ObjectType): result = schema.execute(query) assert not result.errors assert result.data == {"reporters": []} + assert_async_result_equal(schema, query, result) ReporterModel.objects.create(first_name="Tara", last_name="West") ReporterModel.objects.create(first_name="Debra", last_name="Payne") @@ -112,6 +116,7 @@ class Query(ObjectType): assert result.data == { "reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}] } + assert_async_result_equal(schema, query, result) def test_override_resolver(self): class Reporter(DjangoObjectType): @@ -139,6 +144,37 @@ def resolve_reporters(_, info): ReporterModel.objects.create(first_name="Debra", last_name="Payne") result = schema.execute(query) + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Tara"}]} + + def test_override_resolver_async_execution(self): + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name",) + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + @staticmethod + @sync_to_async + def resolve_reporters(_, info): + return ReporterModel.objects.filter(first_name="Tara") + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + result = async_to_sync(schema.execute_async)(query) assert not result.errors assert result.data == {"reporters": [{"firstName": "Tara"}]} @@ -203,6 +239,7 @@ class Query(ObjectType): {"firstName": "Debra", "articles": []}, ] } + assert_async_result_equal(schema, query, result) def test_override_resolver_nested_list_field(self): class Article(DjangoObjectType): @@ -261,6 +298,7 @@ class Query(ObjectType): {"firstName": "Debra", "articles": []}, ] } + assert_async_result_equal(schema, query, result) def test_get_queryset_filter(self): class Reporter(DjangoObjectType): @@ -306,6 +344,7 @@ def resolve_reporters(_, info): assert not result.errors assert result.data == {"reporters": [{"firstName": "Tara"}]} + assert_async_result_equal(schema, query, result) def test_resolve_list(self): """Resolving a plain list should work (and not call get_queryset)""" @@ -354,6 +393,55 @@ def resolve_reporters(_, info): assert not result.errors assert result.data == {"reporters": [{"firstName": "Debra"}]} + def test_resolve_list_async(self): + """Resolving a plain list should work (and not call get_queryset) when running under async""" + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + @staticmethod + @sync_to_async + def resolve_reporters(_, info): + return [ReporterModel.objects.get(first_name="Debra")] + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = async_to_sync(schema.execute_async)(query) + + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Debra"}]} + def test_get_queryset_foreign_key(self): class Article(DjangoObjectType): class Meta: @@ -413,6 +501,7 @@ class Query(ObjectType): {"firstName": "Debra", "articles": []}, ] } + assert_async_result_equal(schema, query, result) def test_resolve_list_external_resolver(self): """Resolving a plain list from external resolver should work (and not call get_queryset)""" @@ -461,6 +550,54 @@ class Query(ObjectType): assert not result.errors assert result.data == {"reporters": [{"firstName": "Debra"}]} + def test_resolve_list_external_resolver_async(self): + """Resolving a plain list from external resolver should work (and not call get_queryset)""" + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + @sync_to_async + def resolve_reporters(_, info): + return [ReporterModel.objects.get(first_name="Debra")] + + class Query(ObjectType): + reporters = DjangoListField(Reporter, resolver=resolve_reporters) + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = async_to_sync(schema.execute_async)(query) + + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Debra"}]} + def test_get_queryset_filter_external_resolver(self): class Reporter(DjangoObjectType): class Meta: @@ -505,6 +642,7 @@ class Query(ObjectType): assert not result.errors assert result.data == {"reporters": [{"firstName": "Tara"}]} + assert_async_result_equal(schema, query, result) def test_select_related_and_prefetch_related_are_respected( self, django_assert_num_queries @@ -647,3 +785,4 @@ def resolve_articles(root, info): r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"', captured.captured_queries[1]["sql"], ) + assert_async_result_equal(schema, query, result)