From 849217a7731cbcfcf64432ccf306f4e4001328f2 Mon Sep 17 00:00:00 2001 From: Chris Berks Date: Thu, 4 Jun 2020 21:13:05 +0100 Subject: [PATCH] Add support for Non-Null SQLAlchemyConnectionField (#261) * Add support for Non-Null SQLAlchemyConnectionField * Remove implicit ORDER BY clause to fix tests with SQLAlchemy 1.3.16 --- graphene_sqlalchemy/fields.py | 50 ++++++++++++++++------ graphene_sqlalchemy/tests/test_batching.py | 8 ++-- graphene_sqlalchemy/tests/test_fields.py | 16 ++++++- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 254319f9..780fcbf0 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -5,6 +5,7 @@ from promise import Promise, is_thenable from sqlalchemy.orm.query import Query +from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice @@ -19,19 +20,26 @@ def type(self): from .types import SQLAlchemyObjectType _type = super(ConnectionField, self).type - if issubclass(_type, Connection): + nullable_type = get_nullable_type(_type) + if issubclass(nullable_type, Connection): return _type - assert issubclass(_type, SQLAlchemyObjectType), ( + assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" - ).format(_type.__name__) - assert _type.connection, "The type {} doesn't have a connection".format( - _type.__name__ + ).format(nullable_type.__name__) + assert ( + nullable_type.connection + ), "The type {} doesn't have a connection".format( + nullable_type.__name__ ) - return _type.connection + assert _type == nullable_type, ( + "Passing a SQLAlchemyObjectType instance is deprecated. " + "Pass the connection type instead accessible via SQLAlchemyObjectType.connection" + ) + return nullable_type.connection @property def model(self): - return self.type._meta.node._meta.model + return get_nullable_type(self.type)._meta.node._meta.model @classmethod def get_query(cls, model, info, **args): @@ -70,21 +78,27 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg return on_resolve(resolved) def get_resolver(self, parent_resolver): - return partial(self.connection_resolver, parent_resolver, self.type, self.model) + return partial( + self.connection_resolver, + parent_resolver, + get_nullable_type(self.type), + self.model, + ) # TODO Rename this to SortableSQLAlchemyConnectionField class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): def __init__(self, type, *args, **kwargs): - if "sort" not in kwargs and issubclass(type, Connection): + nullable_type = get_nullable_type(type) + if "sort" not in kwargs and issubclass(nullable_type, Connection): # Let super class raise if type is not a Connection try: - kwargs.setdefault("sort", type.Edge.node._type.sort_argument()) + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) except (AttributeError, TypeError): raise TypeError( 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' " to None to disabling the creation of the sort query argument".format( - type.__name__ + nullable_type.__name__ ) ) elif "sort" in kwargs and kwargs["sort"] is None: @@ -108,8 +122,14 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): The API and behavior may change in future versions. Use at your own risk. """ + def get_resolver(self, parent_resolver): - return partial(self.connection_resolver, self.resolver, self.type, self.model) + return partial( + self.connection_resolver, + self.resolver, + get_nullable_type(self.type), + self.model, + ) @classmethod def from_relationship(cls, relationship, registry, **field_kwargs): @@ -155,3 +175,9 @@ def unregisterConnectionFieldFactory(): ) global __connectionFactory __connectionFactory = UnsortedSQLAlchemyConnectionField + + +def get_nullable_type(_type): + if isinstance(_type, NonNull): + return _type.of_type + return _type diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index b97002a7..fc646a3c 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -233,8 +233,7 @@ def test_one_to_one(session_factory): 'articles.headline AS articles_headline, ' 'articles.pub_date AS articles_pub_date \n' 'FROM articles \n' - 'WHERE articles.reporter_id IN (?, ?) ' - 'ORDER BY articles.reporter_id', + 'WHERE articles.reporter_id IN (?, ?)', '(1, 2)' ] @@ -337,8 +336,7 @@ def test_one_to_many(session_factory): 'articles.headline AS articles_headline, ' 'articles.pub_date AS articles_pub_date \n' 'FROM articles \n' - 'WHERE articles.reporter_id IN (?, ?) ' - 'ORDER BY articles.reporter_id', + 'WHERE articles.reporter_id IN (?, ?)', '(1, 2)' ] @@ -470,7 +468,7 @@ def test_many_to_many(session_factory): 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id ' 'JOIN pets ON pets.id = association_1.pet_id \n' 'WHERE reporters_1.id IN (?, ?) ' - 'ORDER BY reporters_1.id, pets.id', + 'ORDER BY pets.id', '(1, 2)' ] diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 9ed3c4aa..357055e3 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,7 +1,7 @@ import pytest from promise import Promise -from graphene import ObjectType +from graphene import NonNull, ObjectType from graphene.relay import Connection, Node from ..fields import (SQLAlchemyConnectionField, @@ -26,6 +26,20 @@ class Meta: ## +def test_nonnull_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(NonNull(Pet.connection)) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Pet + + +def test_required_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(Pet.connection, required=True) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Pet + + def test_promise_connection_resolver(): def resolver(_obj, _info): return Promise.resolve([])