Skip to content

Commit

Permalink
Add batching params (#260)
Browse files Browse the repository at this point in the history
Add parameters to toggle batching on or off. This can be configured at 2 levels:
- we can configure all the fields of a type at once via SQLAlchemyObjectType.meta.batching
- or we can specify it for a specific field via ORMfield.batching. This trumps SQLAlchemyObjectType.meta.batching.
  • Loading branch information
jnak authored Feb 12, 2020
1 parent 7a48d3d commit 17d535e
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 120 deletions.
96 changes: 76 additions & 20 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,28 @@
from singledispatch import singledispatch
from sqlalchemy import types
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import interfaces
from sqlalchemy.orm import interfaces, strategies

from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
String)
from graphene.types.json import JSONString

from .batching import get_batch_resolver
from .enums import enum_for_sa_enum
from .fields import (BatchSQLAlchemyConnectionField,
default_connection_field_factory)
from .registry import get_global_registry
from .resolvers import get_attr_resolver, get_custom_resolver

try:
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
except ImportError:
ChoiceType = JSONType = ScalarListType = TSVectorType = object


is_selectin_available = getattr(strategies, 'SelectInLoader', None)


def get_column_doc(column):
return getattr(column, "doc", None)

Expand All @@ -26,33 +33,82 @@ def is_column_nullable(column):
return bool(getattr(column, "nullable", True))


def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, resolver, **field_kwargs):
direction = relationship_prop.direction
model = relationship_prop.mapper.entity

def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching,
orm_field_name, **field_kwargs):
"""
:param sqlalchemy.RelationshipProperty relationship_prop:
:param SQLAlchemyObjectType obj_type:
:param function|None connection_field_factory:
:param bool batching:
:param str orm_field_name:
:param dict field_kwargs:
:rtype: Dynamic
"""
def dynamic_type():
_type = registry.get_type_for_model(model)
""":rtype: Field|None"""
direction = relationship_prop.direction
child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)
batching_ = batching if is_selectin_available else False

if not _type:
if not child_type:
return None

if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
return Field(
_type,
resolver=resolver,
**field_kwargs
)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
if _type.connection:
# TODO Add a way to override connection_field_factory
return connection_field_factory(relationship_prop, registry, **field_kwargs)
return Field(
List(_type),
**field_kwargs
)
return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name,
**field_kwargs)

if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_,
connection_field_factory, **field_kwargs)

return Dynamic(dynamic_type)


def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs):
"""
Convert one-to-one or many-to-one relationshsip. Return an object field.
:param sqlalchemy.RelationshipProperty relationship_prop:
:param SQLAlchemyObjectType obj_type:
:param bool batching:
:param str orm_field_name:
:param dict field_kwargs:
:rtype: Field
"""
child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)

resolver = get_custom_resolver(obj_type, orm_field_name)
if resolver is None:
resolver = get_batch_resolver(relationship_prop) if batching else \
get_attr_resolver(obj_type, relationship_prop.key)

return Field(child_type, resolver=resolver, **field_kwargs)


def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs):
"""
Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field.
:param sqlalchemy.RelationshipProperty relationship_prop:
:param SQLAlchemyObjectType obj_type:
:param bool batching:
:param function|None connection_field_factory:
:param dict field_kwargs:
:rtype: Field
"""
child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)

if not child_type._meta.connection:
return Field(List(child_type), **field_kwargs)

# TODO Allow override of connection_field_factory and resolver via ORMField
if connection_field_factory is None:
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \
default_connection_field_factory

return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs)


def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs):
if 'type' not in field_kwargs:
# TODO The default type should be dependent on the type of the property propety.
Expand Down
Empty file removed graphene_sqlalchemy/resolver.py
Empty file.
26 changes: 26 additions & 0 deletions graphene_sqlalchemy/resolvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from graphene.utils.get_unbound_function import get_unbound_function


def get_custom_resolver(obj_type, orm_field_name):
"""
Since `graphene` will call `resolve_<field_name>` on a field only if it
does not have a `resolver`, we need to re-implement that logic here so
users are able to override the default resolvers that we provide.
"""
resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None)
if resolver:
return get_unbound_function(resolver)

return None


def get_attr_resolver(obj_type, model_attr):
"""
In order to support field renaming via `ORMField.model_attr`,
we need to define resolver functions for each field.
:param SQLAlchemyObjectType obj_type:
:param str model_attr: the name of the SQLAlchemy attribute
:rtype: Callable
"""
return lambda root, _info: getattr(root, model_attr, None)
195 changes: 190 additions & 5 deletions graphene_sqlalchemy/tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import graphene
from graphene import relay

from ..fields import BatchSQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType
from ..fields import (BatchSQLAlchemyConnectionField,
default_connection_field_factory)
from ..types import ORMField, SQLAlchemyObjectType
from .models import Article, HairKind, Pet, Reporter
from .utils import is_sqlalchemy_version_less_than, to_std_dicts

Expand Down Expand Up @@ -43,19 +44,19 @@ class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
batching = True

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (relay.Node,)
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
batching = True

class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet
interfaces = (relay.Node,)
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
batching = True

class Query(graphene.ObjectType):
articles = graphene.Field(graphene.List(ArticleType))
Expand Down Expand Up @@ -513,3 +514,187 @@ def test_many_to_many(session_factory):
},
],
}


def test_disable_batching_via_ormfield(session_factory):
session = session_factory()
reporter_1 = Reporter(first_name='Reporter_1')
session.add(reporter_1)
reporter_2 = Reporter(first_name='Reporter_2')
session.add(reporter_2)
session.commit()
session.close()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
batching = True

favorite_article = ORMField(batching=False)
articles = ORMField(batching=False)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (relay.Node,)

class Query(graphene.ObjectType):
reporters = graphene.Field(graphene.List(ReporterType))

def resolve_reporters(self, info):
return info.context.get('session').query(Reporter).all()

schema = graphene.Schema(query=Query)

# Test one-to-one and many-to-one relationships
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
session = session_factory()
schema.execute("""
query {
reporters {
favoriteArticle {
headline
}
}
}
""", context_value={"session": session})
messages = sqlalchemy_logging_handler.messages

select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
assert len(select_statements) == 2

# Test one-to-many and many-to-many relationships
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
session = session_factory()
schema.execute("""
query {
reporters {
articles {
edges {
node {
headline
}
}
}
}
}
""", context_value={"session": session})
messages = sqlalchemy_logging_handler.messages

select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
assert len(select_statements) == 2


def test_connection_factory_field_overrides_batching_is_false(session_factory):
session = session_factory()
reporter_1 = Reporter(first_name='Reporter_1')
session.add(reporter_1)
reporter_2 = Reporter(first_name='Reporter_2')
session.add(reporter_2)
session.commit()
session.close()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
batching = False
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship

articles = ORMField(batching=False)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (relay.Node,)

class Query(graphene.ObjectType):
reporters = graphene.Field(graphene.List(ReporterType))

def resolve_reporters(self, info):
return info.context.get('session').query(Reporter).all()

schema = graphene.Schema(query=Query)

with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
session = session_factory()
schema.execute("""
query {
reporters {
articles {
edges {
node {
headline
}
}
}
}
}
""", context_value={"session": session})
messages = sqlalchemy_logging_handler.messages

if is_sqlalchemy_version_less_than('1.3'):
# The batched SQL statement generated is different in 1.2.x
# SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
# See https://git.io/JewQu
select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message]
else:
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
assert len(select_statements) == 1


def test_connection_factory_field_overrides_batching_is_true(session_factory):
session = session_factory()
reporter_1 = Reporter(first_name='Reporter_1')
session.add(reporter_1)
reporter_2 = Reporter(first_name='Reporter_2')
session.add(reporter_2)
session.commit()
session.close()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
batching = True
connection_field_factory = default_connection_field_factory

articles = ORMField(batching=True)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (relay.Node,)

class Query(graphene.ObjectType):
reporters = graphene.Field(graphene.List(ReporterType))

def resolve_reporters(self, info):
return info.context.get('session').query(Reporter).all()

schema = graphene.Schema(query=Query)

with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
session = session_factory()
schema.execute("""
query {
reporters {
articles {
edges {
node {
headline
}
}
}
}
}
""", context_value={"session": session})
messages = sqlalchemy_logging_handler.messages

select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
assert len(select_statements) == 2
Loading

0 comments on commit 17d535e

Please sign in to comment.