diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 4ff55eed..ae90001b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -42,7 +42,7 @@ def dynamic_type(): **field_kwargs ) elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): - if _type._meta.connection: + if _type.connection: # TODO Add a way to override connection_field_factory return connection_field_factory(relationship_prop, registry, **field_kwargs) return Field( diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 840204ae..254319f9 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -24,10 +24,10 @@ def type(self): assert issubclass(_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" ).format(_type.__name__) - assert _type._meta.connection, "The type {} doesn't have a connection".format( + assert _type.connection, "The type {} doesn't have a connection".format( _type.__name__ ) - return _type._meta.connection + return _type.connection @property def model(self): @@ -115,7 +115,7 @@ def get_resolver(self, parent_resolver): def from_relationship(cls, relationship, registry, **field_kwargs): model = relationship.mapper.entity model_type = registry.get_type_for_model(model) - return cls(model_type._meta.connection, resolver=get_batch_resolver(relationship), **field_kwargs) + return cls(model_type.connection, resolver=get_batch_resolver(relationship), **field_kwargs) def default_connection_field_factory(relationship, registry, **field_kwargs): diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 557ff114..9ed3c4aa 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -31,7 +31,7 @@ def resolver(_obj, _info): return Promise.resolve([]) result = UnsortedSQLAlchemyConnectionField.connection_resolver( - resolver, Pet._meta.connection, Pet, None, None + resolver, Pet.connection, Pet, None, None ) assert isinstance(result, Promise) @@ -51,18 +51,18 @@ def test_type_assert_object_has_connection(): def test_sort_added_by_default(): - field = SQLAlchemyConnectionField(Pet._meta.connection) + field = SQLAlchemyConnectionField(Pet.connection) assert "sort" in field.args assert field.args["sort"] == Pet.sort_argument() def test_sort_can_be_removed(): - field = SQLAlchemyConnectionField(Pet._meta.connection, sort=None) + field = SQLAlchemyConnectionField(Pet.connection, sort=None) assert "sort" not in field.args def test_custom_sort(): - field = SQLAlchemyConnectionField(Pet._meta.connection, sort=Editor.sort_argument()) + field = SQLAlchemyConnectionField(Pet.connection, sort=Editor.sort_argument()) assert field.args["sort"] == Editor.sort_argument() diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index fda8e659..bf563b6e 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,6 +4,7 @@ from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, ObjectType, Schema, String) +from graphene.relay import Connection from ..converter import convert_sqlalchemy_composite from ..fields import (SQLAlchemyConnectionField, @@ -46,6 +47,15 @@ class Meta: assert reporter == reporter_node +def test_connection(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + assert issubclass(ReporterType.connection, Connection) + + def test_sqlalchemy_default_fields(): @convert_sqlalchemy_composite.register(CompositeFullName) def convert_composite_class(composite, registry): diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 2ed5110e..ef189b38 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -325,6 +325,8 @@ def __init_subclass_with_meta__( _meta.connection = connection _meta.id = id or "id" + cls.connection = connection # Public way to get the connection + super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( _meta=_meta, interfaces=interfaces, **options )