Skip to content

Commit

Permalink
Add support for Non-Null SQLAlchemyConnectionField (#261)
Browse files Browse the repository at this point in the history
* Add support for Non-Null SQLAlchemyConnectionField
* Remove implicit ORDER BY clause to fix tests with SQLAlchemy 1.3.16
  • Loading branch information
chrisberks authored Jun 4, 2020
1 parent 421f8e4 commit 849217a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 18 deletions.
50 changes: 38 additions & 12 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from promise import Promise, is_thenable
from sqlalchemy.orm.query import Query

from graphene import NonNull
from graphene.relay import Connection, ConnectionField
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
Expand All @@ -19,19 +20,26 @@ def type(self):
from .types import SQLAlchemyObjectType

_type = super(ConnectionField, self).type
if issubclass(_type, Connection):
nullable_type = get_nullable_type(_type)
if issubclass(nullable_type, Connection):
return _type
assert issubclass(_type, SQLAlchemyObjectType), (
assert issubclass(nullable_type, SQLAlchemyObjectType), (
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
).format(_type.__name__)
assert _type.connection, "The type {} doesn't have a connection".format(
_type.__name__
).format(nullable_type.__name__)
assert (
nullable_type.connection
), "The type {} doesn't have a connection".format(
nullable_type.__name__
)
return _type.connection
assert _type == nullable_type, (
"Passing a SQLAlchemyObjectType instance is deprecated. "
"Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
)
return nullable_type.connection

@property
def model(self):
return self.type._meta.node._meta.model
return get_nullable_type(self.type)._meta.node._meta.model

@classmethod
def get_query(cls, model, info, **args):
Expand Down Expand Up @@ -70,21 +78,27 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg
return on_resolve(resolved)

def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
return partial(
self.connection_resolver,
parent_resolver,
get_nullable_type(self.type),
self.model,
)


# TODO Rename this to SortableSQLAlchemyConnectionField
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
def __init__(self, type, *args, **kwargs):
if "sort" not in kwargs and issubclass(type, Connection):
nullable_type = get_nullable_type(type)
if "sort" not in kwargs and issubclass(nullable_type, Connection):
# Let super class raise if type is not a Connection
try:
kwargs.setdefault("sort", type.Edge.node._type.sort_argument())
kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
except (AttributeError, TypeError):
raise TypeError(
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
" to None to disabling the creation of the sort query argument".format(
type.__name__
nullable_type.__name__
)
)
elif "sort" in kwargs and kwargs["sort"] is None:
Expand All @@ -108,8 +122,14 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
The API and behavior may change in future versions.
Use at your own risk.
"""

def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, self.resolver, self.type, self.model)
return partial(
self.connection_resolver,
self.resolver,
get_nullable_type(self.type),
self.model,
)

@classmethod
def from_relationship(cls, relationship, registry, **field_kwargs):
Expand Down Expand Up @@ -155,3 +175,9 @@ def unregisterConnectionFieldFactory():
)
global __connectionFactory
__connectionFactory = UnsortedSQLAlchemyConnectionField


def get_nullable_type(_type):
if isinstance(_type, NonNull):
return _type.of_type
return _type
8 changes: 3 additions & 5 deletions graphene_sqlalchemy/tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ def test_one_to_one(session_factory):
'articles.headline AS articles_headline, '
'articles.pub_date AS articles_pub_date \n'
'FROM articles \n'
'WHERE articles.reporter_id IN (?, ?) '
'ORDER BY articles.reporter_id',
'WHERE articles.reporter_id IN (?, ?)',
'(1, 2)'
]

Expand Down Expand Up @@ -337,8 +336,7 @@ def test_one_to_many(session_factory):
'articles.headline AS articles_headline, '
'articles.pub_date AS articles_pub_date \n'
'FROM articles \n'
'WHERE articles.reporter_id IN (?, ?) '
'ORDER BY articles.reporter_id',
'WHERE articles.reporter_id IN (?, ?)',
'(1, 2)'
]

Expand Down Expand Up @@ -470,7 +468,7 @@ def test_many_to_many(session_factory):
'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id '
'JOIN pets ON pets.id = association_1.pet_id \n'
'WHERE reporters_1.id IN (?, ?) '
'ORDER BY reporters_1.id, pets.id',
'ORDER BY pets.id',
'(1, 2)'
]

Expand Down
16 changes: 15 additions & 1 deletion graphene_sqlalchemy/tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from promise import Promise

from graphene import ObjectType
from graphene import NonNull, ObjectType
from graphene.relay import Connection, Node

from ..fields import (SQLAlchemyConnectionField,
Expand All @@ -26,6 +26,20 @@ class Meta:
##


def test_nonnull_sqlalachemy_connection():
field = SQLAlchemyConnectionField(NonNull(Pet.connection))
assert isinstance(field.type, NonNull)
assert issubclass(field.type.of_type, Connection)
assert field.type.of_type._meta.node is Pet


def test_required_sqlalachemy_connection():
field = SQLAlchemyConnectionField(Pet.connection, required=True)
assert isinstance(field.type, NonNull)
assert issubclass(field.type.of_type, Connection)
assert field.type.of_type._meta.node is Pet


def test_promise_connection_resolver():
def resolver(_obj, _info):
return Promise.resolve([])
Expand Down

0 comments on commit 849217a

Please sign in to comment.