diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index e56b1e4c..f6f14a6e 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -1,5 +1,6 @@ """The dataloader uses "select in loading" strategy to load related entities.""" -from typing import Any +from asyncio import get_event_loop +from typing import Any, Dict import aiodataloader import sqlalchemy @@ -10,6 +11,90 @@ is_sqlalchemy_version_less_than) +class RelationshipLoader(aiodataloader.DataLoader): + cache = False + + def __init__(self, relationship_prop, selectin_loader): + super().__init__() + self.relationship_prop = relationship_prop + self.selectin_loader = selectin_loader + + async def batch_load_fn(self, parents): + """ + Batch loads the relationships of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. It's a bit hacky but it's preferable + to re-implementing and maintaining a big chunk of the `selectin` + loader logic ourselves. + + The approach here is to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emit 2 `SELECT` statements + when calling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. + + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maintainer suggestion. + See https://git.io/JewQ7 + """ + child_mapper = self.relationship_prop.mapper + parent_mapper = self.relationship_prop.parent + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... 
+ for parent in parents: + # assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + # Should the boolean be set to False? Does it matter for our purposes? + states = [(sqlalchemy.inspect(parent), True) for parent in parents] + + # For our purposes, the query_context will only be used to get the session + query_context = None + if is_sqlalchemy_version_less_than('1.4'): + query_context = QueryContext(session.query(parent_mapper.entity)) + else: + parent_mapper_query = session.query(parent_mapper.entity) + query_context = parent_mapper_query._compile_context() + + if is_sqlalchemy_version_less_than('1.4'): + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + ) + else: + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + None, + ) + return [ + getattr(parent, self.relationship_prop.key) for parent in parents + ] + + +# Cache this across `batch_load_fn` calls +# This is so SQL string generation is cached under-the-hood via `bakery` +# Caching the relationship loader for each relationship prop. +RELATIONSHIP_LOADERS_CACHE: Dict[ + sqlalchemy.orm.relationships.RelationshipProperty, RelationshipLoader +] = {} + + def get_data_loader_impl() -> Any: # pragma: no cover """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. 
To preserve backward-compatibility, aiodataloader is used in conjunction with older versions of graphene""" @@ -25,80 +110,23 @@ def get_data_loader_impl() -> Any: # pragma: no cover def get_batch_resolver(relationship_prop): - # Cache this across `batch_load_fn` calls - # This is so SQL string generation is cached under-the-hood via `bakery` - selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - - class RelationshipLoader(aiodataloader.DataLoader): - cache = False - - async def batch_load_fn(self, parents): - """ - Batch loads the relationships of all the parents as one SQL statement. - - There is no way to do this out-of-the-box with SQLAlchemy but - we can piggyback on some internal APIs of the `selectin` - eager loading strategy. It's a bit hacky but it's preferable - than re-implementing and maintainnig a big chunk of the `selectin` - loader logic ourselves. - - The approach here is to build a regular query that - selects the parent and `selectin` load the relationship. - But instead of having the query emits 2 `SELECT` statements - when callling `all()`, we skip the first `SELECT` statement - and jump right before the `selectin` loader is called. - To accomplish this, we have to construct objects that are - normally built in the first part of the query in order - to call directly `SelectInLoader._load_for_path`. - - TODO Move this logic to a util in the SQLAlchemy repo as per - SQLAlchemy's main maitainer suggestion. - See https://git.io/JewQ7 - """ - child_mapper = relationship_prop.mapper - parent_mapper = relationship_prop.parent - session = Session.object_session(parents[0]) - - # These issues are very unlikely to happen in practice... 
- for parent in parents: - # assert parent.__mapper__ is parent_mapper - # All instances must share the same session - assert session is Session.object_session(parent) - # The behavior of `selectin` is undefined if the parent is dirty - assert parent not in session.dirty - - # Should the boolean be set to False? Does it matter for our purposes? - states = [(sqlalchemy.inspect(parent), True) for parent in parents] - - # For our purposes, the query_context will only used to get the session - query_context = None - if is_sqlalchemy_version_less_than('1.4'): - query_context = QueryContext(session.query(parent_mapper.entity)) - else: - parent_mapper_query = session.query(parent_mapper.entity) - query_context = parent_mapper_query._compile_context() - - if is_sqlalchemy_version_less_than('1.4'): - selectin_loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper - ) - else: - selectin_loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper, - None - ) - - return [getattr(parent, relationship_prop.key) for parent in parents] - - loader = RelationshipLoader() + """Get the resolve function for the given relationship.""" + + def _get_loader(relationship_prop): + """Retrieve the cached loader of the given relationship.""" + loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None) + if loader is None or loader.loop != get_event_loop(): + selectin_loader = strategies.SelectInLoader( + relationship_prop, (('lazy', 'selectin'),) + ) + loader = RelationshipLoader( + relationship_prop=relationship_prop, + selectin_loader=selectin_loader, + ) + RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader + return loader + + loader = _get_loader(relationship_prop) async def resolve(root, info, **args): return await loader.load(root) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index a2ed17ad..19f40b7f 100644 --- a/graphene_sqlalchemy/enums.py +++ 
b/graphene_sqlalchemy/enums.py @@ -144,9 +144,9 @@ def sort_enum_for_object_type( column = orm_field.columns[0] if only_indexed and not (column.primary_key or column.index): continue - asc_name = get_name(column.key, True) + asc_name = get_name(field_name, True) asc_value = EnumValue(asc_name, column.asc()) - desc_name = get_name(column.key, False) + desc_name = get_name(field_name, False) desc_value = EnumValue(desc_name, column.desc()) if column.primary_key: default.append(asc_value) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index d7a83392..9b4b8436 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -14,7 +14,7 @@ from .utils import EnumValue, get_query -class UnsortedSQLAlchemyConnectionField(ConnectionField): +class SQLAlchemyConnectionField(ConnectionField): @property def type(self): from .types import SQLAlchemyObjectType @@ -37,13 +37,45 @@ def type(self): ) return nullable_type.connection + def __init__(self, type_, *args, **kwargs): + nullable_type = get_nullable_type(type_) + if "sort" not in kwargs and nullable_type and issubclass(nullable_type, Connection): + # Let super class raise if type is not a Connection + try: + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( + 'Cannot create sort argument for {}. A model is required. 
Set the "sort" argument' + " to None to disabling the creation of the sort query argument".format( + nullable_type.__name__ + ) + ) + elif "sort" in kwargs and kwargs["sort"] is None: + del kwargs["sort"] + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) + @property def model(self): return get_nullable_type(self.type)._meta.node._meta.model @classmethod - def get_query(cls, model, info, **args): - return get_query(model, info.context) + def get_query(cls, model, info, sort=None, **args): + query = get_query(model, info.context) + if sort is not None: + if not isinstance(sort, list): + sort = [sort] + sort_args = [] + # ensure consistent handling of graphene Enums, enum values and + # plain strings + for item in sort: + if isinstance(item, enum.Enum): + sort_args.append(item.value.value) + elif isinstance(item, EnumValue): + sort_args.append(item.value) + else: + sort_args.append(item) + query = query.order_by(*sort_args) + return query @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): @@ -90,59 +122,49 @@ def wrap_resolve(self, parent_resolver): ) -# TODO Rename this to SortableSQLAlchemyConnectionField -class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): +# TODO Remove in next major version +class UnsortedSQLAlchemyConnectionField(SQLAlchemyConnectionField): def __init__(self, type_, *args, **kwargs): - nullable_type = get_nullable_type(type_) - if "sort" not in kwargs and issubclass(nullable_type, Connection): - # Let super class raise if type is not a Connection - try: - kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) - except (AttributeError, TypeError): - raise TypeError( - 'Cannot create sort argument for {}. A model is required. 
Set the "sort" argument' - " to None to disabling the creation of the sort query argument".format( - nullable_type.__name__ - ) - ) - elif "sort" in kwargs and kwargs["sort"] is None: - del kwargs["sort"] - super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) - - @classmethod - def get_query(cls, model, info, sort=None, **args): - query = get_query(model, info.context) - if sort is not None: - if not isinstance(sort, list): - sort = [sort] - sort_args = [] - # ensure consistent handling of graphene Enums, enum values and - # plain strings - for item in sort: - if isinstance(item, enum.Enum): - sort_args.append(item.value.value) - elif isinstance(item, EnumValue): - sort_args.append(item.value) - else: - sort_args.append(item) - query = query.order_by(*sort_args) - return query + if "sort" in kwargs and kwargs["sort"] is not None: + warnings.warn( + "UnsortedSQLAlchemyConnectionField does not support sorting. " + "All sorting arguments will be ignored." + ) + kwargs["sort"] = None + warnings.warn( + "UnsortedSQLAlchemyConnectionField is deprecated and will be removed in the next " + "major version. Use SQLAlchemyConnectionField instead and either don't " + "provide the `sort` argument or set it to None if you do not want sorting.", + DeprecationWarning, + ) + super(UnsortedSQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) -class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): +class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField): """ This is currently experimental. The API and behavior may change in future versions. Use at your own risk. 
""" - def wrap_resolve(self, parent_resolver): - return partial( - self.connection_resolver, - self.resolver, - get_nullable_type(self.type), - self.model, - ) + @classmethod + def connection_resolver(cls, resolver, connection_type, model, root, info, **args): + if root is None: + resolved = resolver(root, info, **args) + on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + else: + relationship_prop = None + for relationship in root.__class__.__mapper__.relationships: + if relationship.mapper.class_ == model: + relationship_prop = relationship + break + resolved = get_batch_resolver(relationship_prop)(root, info, **args) + on_resolve = partial(cls.resolve_connection, connection_type, root, info, args) + + if is_thenable(resolved): + return Promise.resolve(resolved).then(on_resolve) + + return on_resolve(resolved) @classmethod def from_relationship(cls, relationship, registry, **field_kwargs): diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index dc399ee0..c7a1d664 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -110,6 +110,24 @@ class Article(Base): headline = Column(String(100)) pub_date = Column(Date()) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + readers = relationship( + "Reader", secondary="articles_readers", back_populates="articles" + ) + + +class Reader(Base): + __tablename__ = "readers" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + articles = relationship( + "Article", secondary="articles_readers", back_populates="readers" + ) + + +class ArticleReader(Base): + __tablename__ = "articles_readers" + article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) + reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) class ReflectedEditor(type): diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 1896900b..fc4e6649 
100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -5,13 +5,13 @@ import pytest import graphene -from graphene import relay +from graphene import Connection, relay from ..fields import (BatchSQLAlchemyConnectionField, default_connection_field_factory) from ..types import ORMField, SQLAlchemyObjectType from ..utils import is_sqlalchemy_version_less_than -from .models import Article, HairKind, Pet, Reporter +from .models import Article, HairKind, Pet, Reader, Reporter from .utils import remove_cache_miss_stat, to_std_dicts @@ -73,6 +73,40 @@ def resolve_reporters(self, info): return graphene.Schema(query=Query) +def get_full_relay_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class ReaderType(SQLAlchemyObjectType): + class Meta: + model = Reader + name = "Reader" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = BatchSQLAlchemyConnectionField(ArticleType.connection) + reporters = BatchSQLAlchemyConnectionField(ReporterType.connection) + readers = BatchSQLAlchemyConnectionField(ReaderType.connection) + + return graphene.Schema(query=Query) + + if is_sqlalchemy_version_less_than('1.2'): pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) @@ -82,11 +116,11 @@ async def test_many_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -138,20 +172,20 @@ async def 
test_many_to_one(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "articles": [ - { - "headline": "Article_1", - "reporter": { - "firstName": "Reporter_1", - }, - }, - { - "headline": "Article_2", - "reporter": { - "firstName": "Reporter_2", - }, - }, - ], + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], } @@ -160,11 +194,11 @@ async def test_one_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -185,14 +219,14 @@ async def test_one_to_one(session_factory): # Starts new session to fully reset the engine / connection logging level session = session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - favoriteArticle { - headline - } + query { + reporters { + firstName + favoriteArticle { + headline + } + } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -216,20 +250,20 @@ async def test_one_to_one(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "favoriteArticle": { - "headline": "Article_1", - }, - }, - { - "firstName": "Reporter_2", - "favoriteArticle": { - "headline": "Article_2", - }, - }, - ], + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], } @@ -238,11 +272,11 @@ async def test_one_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) 
session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -271,18 +305,18 @@ async def test_one_to_many(session_factory): # Starts new session to fully reset the engine / connection logging level session = session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - articles(first: 2) { - edges { - node { - headline - } + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } } - } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -306,42 +340,42 @@ async def test_one_to_many(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "articles": { - "edges": [ - { - "node": { - "headline": "Article_1", - }, - }, - { - "node": { - "headline": "Article_2", - }, - }, - ], - }, - }, - { - "firstName": "Reporter_2", - "articles": { - "edges": [ - { - "node": { - "headline": "Article_3", + "reporters": [ + { + "firstName": "Reporter_1", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_1", + }, + }, + { + "node": { + "headline": "Article_2", + }, + }, + ], }, - }, - { - "node": { - "headline": "Article_4", + }, + { + "firstName": "Reporter_2", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_3", + }, + }, + { + "node": { + "headline": "Article_4", + }, + }, + ], }, - }, - ], - }, - }, - ], + }, + ], } @@ -350,11 +384,11 @@ async def test_many_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -385,18 +419,18 @@ async def test_many_to_many(session_factory): # Starts new session to fully reset the engine / connection logging level session = 
session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - pets(first: 2) { - edges { - node { - name - } + query { + reporters { + firstName + pets(first: 2) { + edges { + node { + name + } + } + } } - } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -420,42 +454,42 @@ async def test_many_to_many(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "pets": { - "edges": [ - { - "node": { - "name": "Pet_1", - }, - }, - { - "node": { - "name": "Pet_2", + "reporters": [ + { + "firstName": "Reporter_1", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_1", + }, + }, + { + "node": { + "name": "Pet_2", + }, + }, + ], }, - }, - ], - }, - }, - { - "firstName": "Reporter_2", - "pets": { - "edges": [ - { - "node": { - "name": "Pet_3", + }, + { + "firstName": "Reporter_2", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_3", + }, + }, + { + "node": { + "name": "Pet_4", + }, + }, + ], }, - }, - { - "node": { - "name": "Pet_4", - }, - }, - ], - }, - }, - ], + }, + ], } @@ -531,6 +565,70 @@ def resolve_reporters(self, info): assert len(select_statements) == 2 +@pytest.mark.asyncio +def test_batch_sorting_with_custom_ormfield(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + firstname = ORMField(model_attr="first_name") + + class Query(graphene.ObjectType): + node = relay.Node.Field() + reporters = BatchSQLAlchemyConnectionField(ReporterType.connection) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = 
Reporter + interfaces = (relay.Node,) + batching = True + + schema = graphene.Schema(query=Query) + + # Test one-to-one and many-to-one relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters(sort: [FIRSTNAME_DESC]) { + edges { + node { + firstname + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + result = to_std_dicts(result.data) + assert result == { + "reporters": {"edges": [ + {"node": { + "firstname": "Reporter_2", + }}, + {"node": { + "firstname": "Reporter_1", + }}, + ]} + } + select_statements = [message for message in messages if 'SELECT' in message and 'FROM reporters' in message] + assert len(select_statements) == 2 + + @pytest.mark.asyncio async def test_connection_factory_field_overrides_batching_is_false(session_factory): session = session_factory() @@ -642,3 +740,106 @@ def resolve_reporters(self, info): select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] assert len(select_statements) == 2 + + +@pytest.mark.asyncio +async def test_batching_across_nested_relay_schema(session_factory): + session = session_factory() + + for first_name in "fgerbhjikzutzxsdfdqqa": + reporter = Reporter( + first_name=first_name, + ) + session.add(reporter) + article = Article(headline='Article') + article.reporter = reporter + session.add(article) + reader = Reader(name='Reader') + reader.articles = [article] + session.add(reader) + + session.commit() + session.close() + + schema = get_full_relay_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = await schema.execute_async(""" + query { + reporters { + edges { + node { + 
firstName + articles { + edges { + node { + id + readers { + edges { + node { + name + } + } + } + } + } + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + result = to_std_dicts(result.data) + select_statements = [message for message in messages if 'SELECT' in message] + assert len(select_statements) == 4 + assert select_statements[-1].startswith("SELECT articles_1.id") + if is_sqlalchemy_version_less_than('1.3'): + assert select_statements[-2].startswith("SELECT reporters_1.id") + assert "WHERE reporters_1.id IN" in select_statements[-2] + else: + assert select_statements[-2].startswith("SELECT articles.reporter_id") + assert "WHERE articles.reporter_id IN" in select_statements[-2] + + +@pytest.mark.asyncio +async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_factory): + session = session_factory() + + for first_name, email in zip("cadbbb", "aaabac"): + reporter_1 = Reporter( + first_name=first_name, + email=email + ) + session.add(reporter_1) + article_1 = Article(headline="headline") + article_1.reporter = reporter_1 + session.add(article_1) + + session.commit() + session.close() + + schema = get_full_relay_schema() + + session = session_factory() + result = await schema.execute_async(""" + query { + reporters(sort: [FIRST_NAME_ASC, EMAIL_ASC]) { + edges { + node { + firstName + email + } + } + } + } + """, context_value={"session": session}) + + result = to_std_dicts(result.data) + assert [ + r["node"]["firstName"] + r["node"]["email"] + for r in result["reporters"]["edges"] + ] == ['aa', 'ba', 'bb', 'bc', 'ca', 'da'] diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 357055e3..2782da89 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -64,6 +64,14 @@ def test_type_assert_object_has_connection(): ## +def test_unsorted_connection_field_removes_sort_arg_if_passed(): + 
editor = UnsortedSQLAlchemyConnectionField( + Editor.connection, + sort=Editor.sort_argument(has_default=True) + ) + assert "sort" not in editor.args + + def test_sort_added_by_default(): field = SQLAlchemyConnectionField(Pet.connection) assert "sort" in field.args