Skip to content

Commit

Permalink
update enum processing
Browse files Browse the repository at this point in the history
  • Loading branch information
art1415926535 committed Jan 29, 2021
1 parent aa0c992 commit 05e58de
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 65 deletions.
77 changes: 12 additions & 65 deletions graphene_sqlalchemy_filter/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import warnings
from copy import deepcopy
from functools import lru_cache
from inspect import isfunction

# GraphQL
import graphene
Expand Down Expand Up @@ -63,7 +62,10 @@ def _get_class(obj: 'GRAPHENE_OBJECT_OR_CLASS') -> 'Type[graphene.ObjectType]':
if inspect.isclass(obj):
return obj

return obj.__class__ # only graphene-sqlalchemy<=2.2.0; pragma: no cover
if isinstance(obj, graphene.Field): # only graphene-sqlalchemy==2.2.0
return obj.type

return obj.__class__ # only graphene-sqlalchemy<2.2.0; pragma: no cover


def _eq_filter(field: 'Column', value: 'Any') -> 'Any':
Expand Down Expand Up @@ -571,11 +573,10 @@ def _get_gql_type_from_sqla_type(
return GenericScalar
else:
_type = convert_sqlalchemy_type(column_type, sqla_column)
if isfunction(_type):
return _type()
if inspect.isfunction(_type):
return _type() # only graphene-sqlalchemy>2.2.0
return _type


@classmethod
def _get_model_fields_data(cls, model, only_fields: 'Iterable[str]'):
"""
Expand Down Expand Up @@ -620,62 +621,6 @@ def _get_model_fields_data(cls, model, only_fields: 'Iterable[str]'):

return model_fields

@staticmethod
def _is_graphene_enum(obj: 'Any') -> bool:
"""
Return whether 'obj' is a enum.
Args:
obj: lambda or graphene.Field
Returns:
boolean
"""
if gqls_version < (2, 2, 0):
# https://github.com/graphql-python/graphene-sqlalchemy/blob/v2.1.2/graphene_sqlalchemy/converter.py#L147
return isinstance(
obj, graphene.Field
) and isinstance( # pragma: no cover
obj._type, graphene.types.enum.EnumMeta
)
elif gqls_version == (2, 2, 0):
# https://github.com/graphql-python/graphene-sqlalchemy/blob/db3e9f4c3baad3e62c113d4a9ddd2e3983d324f2/graphene_sqlalchemy/converter.py#L150
return isinstance(obj, graphene.Field) and callable(
obj._type
) # pragma: no cover
else:
# https://github.com/graphql-python/graphene-sqlalchemy/blob/17d535efba03070cbc505d915673e0f24d9ca60c/graphene_sqlalchemy/converter.py#L216
return callable(obj) and obj.__name__ == '<lambda>'

@staticmethod
def _get_enum_from_field(
enum: 'Union[Callable, graphene.Field]',
) -> graphene.Enum:
"""
Get graphene enum.
Args:
enum: lambda or graphene.Field
Returns:
Graphene enum.
"""
if gqls_version < (2, 2, 0):
# AssertionError: Found different types
# with the same name in the schema: ...
raise AssertionError( # pragma: no cover
'Enum is not supported. '
'Requires graphene-sqlalchemy 2.2.0 or higher.'
)
elif gqls_version == (2, 2, 0):
# https://github.com/graphql-python/graphene-sqlalchemy/compare/2.1.2...2.2.0#diff-9202780f6bf4790a0d960de553c086f1L155
return enum._type()() # pragma: no cover
else:
# https://github.com/graphql-python/graphene-sqlalchemy/compare/2.2.0...2.2.1#diff-9202780f6bf4790a0d960de553c086f1L150
return enum()()

@classmethod
def _generate_filter_fields(
cls,
Expand Down Expand Up @@ -713,8 +658,6 @@ def _generate_filter_fields(
except KeyError:
if isinstance(field_type, graphene.List):
filter_field = field_type
elif cls._is_graphene_enum(field_type):
filter_field = cls._get_enum_from_field(field_type)
else:
field_type = _get_class(field_type)
filter_field = field_type(description=doc)
Expand Down Expand Up @@ -840,8 +783,12 @@ def _translate_filter(
raise KeyError('Field not found: ' + field)

model_field_type = getattr(model_field, 'type', None)
if isinstance(model_field_type, sqltypes.Enum) and model_field_type.enum_class:
value = model_field_type.enum_class(value)
is_enum = isinstance(model_field_type, sqltypes.Enum)
if is_enum and model_field_type.enum_class:
if isinstance(value, list):
value = [model_field_type.enum_class(v) for v in value]
else:
value = model_field_type.enum_class(value)

clause = filter_function(model_field, value)
return query, clause
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ show_missing = True

exclude_lines =
pragma: no cover
only graphene-sqlalchemy==2.2.0
only graphene-sqlalchemy<2.2.0
56 changes: 56 additions & 0 deletions tests/test_enum_without_relay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Third Party
import pytest

# GraphQL
import graphene
from graphene_sqlalchemy import get_query

# Project
from graphene_sqlalchemy_filter import FilterSet
from tests import gqls_version, models


@pytest.mark.skipif(gqls_version < (2, 2, 0), reason='not supported')
def test_enum_filter_without_relay(session):
"""https://github.com/art1415926535/graphene-sqlalchemy-filter/issues/28"""
online = models.StatusEnum.online
users = [
models.User(username='user_1', is_active=True, status=online),
models.User(username='user_2', is_active=True),
]
session.bulk_save_objects(users)

class UserFilter(FilterSet):
class Meta:
model = models.User
fields = {'status': ['eq', 'in']}

class UserType(graphene.ObjectType):
username = graphene.String()

class Query(graphene.ObjectType):
all_users = graphene.List(UserType, filters=UserFilter())

def resolve_all_users(self, info, filters=None):
query = get_query(models.User, info.context)
if filters is not None:
query = UserFilter.filter(info, query, filters)

return query

schema = graphene.Schema(query=Query)

request_string = """
{
allUsers(filters: {status: ONLINE, statusIn: [ONLINE]}) {
username
}
}"""
execution_result = schema.execute(
request_string, context={'session': session}
)

assert not execution_result.errors
assert not execution_result.invalid

assert execution_result.data == {'allUsers': [{'username': 'user_1'}]}

0 comments on commit 05e58de

Please sign in to comment.