diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 00000000..1ae7b4b6 --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,26 @@ +name: 🚀 Deploy to PyPI + +on: + push: + tags: + - '*' + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Build wheel and source tarball + run: | + pip install wheel + python setup.py sdist bdist_wheel + - name: Publish a Python distribution to PyPI + uses: pypa/gh-action-pypi-publish@v1.1.0 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..559326c4 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,22 @@ +name: Lint + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox + - name: Run lint 💅 + run: tox + env: + TOXENV: flake8 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..a9a3bd5d --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,38 @@ +name: Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + max-parallel: 10 + matrix: + sql-alchemy: ["1.2", "1.3", "1.4"] + python-version: ["3.6", "3.7", "3.8", "3.9"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox tox-gh-actions + - name: Test with tox + run: tox + env: + SQLALCHEMY: ${{ matrix.sql-alchemy }} + TOXENV: ${{ matrix.toxenv }} + - name: Upload coverage.xml + if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }} + uses: actions/upload-artifact@v2 + with: + name: graphene-sqlalchemy-coverage + path: coverage.xml + if-no-files-found: error + - name: Upload coverage.xml to codecov + if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }} + uses: codecov/codecov-action@v1 diff --git a/.gitignore b/.gitignore index a97b8c21..c4a735fe 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ target/ # Databases *.sqlite3 .vscode + +# mypy cache +.mypy_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 136f8e7a..1c67ab03 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ default_language_version: python: python3.7 repos: -- repo: git://github.com/pre-commit/pre-commit-hooks +- repo: https://github.com/pre-commit/pre-commit-hooks rev: c8bad492e1b1d65d9126dba3fe3bd49a5a52b9d6 # v2.1.0 hooks: - id: check-merge-conflict @@ -11,15 +11,15 @@ repos: exclude: ^docs/.*$ - id: trailing-whitespace exclude: README.md -- repo: git://github.com/PyCQA/flake8 +- repo: https://github.com/PyCQA/flake8 rev: 88caf5ac484f5c09aedc02167c59c66ff0af0068 # 3.7.7 hooks: - id: flake8 -- repo: git://github.com/asottile/seed-isort-config +- repo: https://github.com/asottile/seed-isort-config rev: v1.7.0 hooks: - id: seed-isort-config -- repo: git://github.com/pre-commit/mirrors-isort +- repo: https://github.com/pre-commit/mirrors-isort rev: v4.3.4 hooks: - id: isort diff --git 
a/.travis.yml b/.travis.yml deleted file mode 100644 index 5a988428..00000000 --- a/.travis.yml +++ /dev/null @@ -1,47 +0,0 @@ -language: python -matrix: - include: - # Python 2.7 - - env: TOXENV=py27 - python: 2.7 - # Python 3.5 - - env: TOXENV=py35 - python: 3.5 - # Python 3.6 - - env: TOXENV=py36 - python: 3.6 - # Python 3.7 - - env: TOXENV=py37 - python: 3.7 - dist: xenial - # SQLAlchemy 1.1 - - env: TOXENV=py37-sql11 - python: 3.7 - dist: xenial - # SQLAlchemy 1.2 - - env: TOXENV=py37-sql12 - python: 3.7 - dist: xenial - # SQLAlchemy 1.3 - - env: TOXENV=py37-sql13 - python: 3.7 - dist: xenial - # Pre-commit - - env: TOXENV=pre-commit - python: 3.7 - dist: xenial -install: pip install .[dev] -script: tox -after_success: coveralls -cache: - directories: - - $HOME/.cache/pip - - $HOME/.cache/pre-commit -deploy: - provider: pypi - user: syrusakbary - on: - tags: true - password: - secure: q0ey31cWljGB30l43aEd1KIPuAHRutzmsd2lBb/2zvD79ReBrzvCdFAkH2xcyo4Volk3aazQQTNUIurnTuvBxmtqja0e+gUaO5LdOcokVdOGyLABXh7qhd2kdvbTDWgSwA4EWneLGXn/SjXSe0f3pCcrwc6WDcLAHxtffMvO9gulpYQtUoOqXfMipMOkRD9iDWTJBsSo3trL70X1FHOVr6Yqi0mfkX2Y/imxn6wlTWRz28Ru94xrj27OmUnCv7qcG0taO8LNlUCquNFAr2sZ+l+U/GkQrrM1y+ehPz3pmI0cCCd7SX/7+EG9ViZ07BZ31nk4pgnqjmj3nFwqnCE/4IApGnduqtrMDF63C9TnB1TU8oJmbbUCu4ODwRpBPZMnwzaHsLnrpdrB89/98NtTfujdrh3U5bVB+t33yxrXVh+FjgLYj9PVeDixpFDn6V/Xcnv4BbRMNOhXIQT7a7/5b99RiXBjCk6KRu+Jdu5DZ+3G4Nbr4oim3kZFPUHa555qbzTlwAfkrQxKv3C3OdVJR7eGc9ADsbHyEJbdPNAh/T+xblXTXLS3hPYDvgM+WEGy3CytBDG3JVcXm25ZP96EDWjweJ7MyfylubhuKj/iR1Y1wiHeIsYq9CqRrFQUWL8gFJBfmgjs96xRXXXnvyLtKUKpKw3wFg5cR/6FnLeYZ8k= - distributions: "sdist bdist_wheel" diff --git a/README.md b/README.md index 9b617069..04692973 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ A [SQLAlchemy](http://www.sqlalchemy.org/) integration for [Graphene](http://gra ## Installation -For instaling graphene, just run this command in your shell +For installing Graphene, just run this command in your shell. 
```bash pip install "graphene-sqlalchemy>=2.0" @@ -34,7 +34,7 @@ class UserModel(Base): last_name = Column(String) ``` -To create a GraphQL schema for it you simply have to write the following: +To create a GraphQL schema for it, you simply have to write the following: ```python import graphene diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 3945d506..060bd13b 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .fields import SQLAlchemyConnectionField from .utils import get_query, get_session -__version__ = "2.3.0" +__version__ = "3.0.0b1" __all__ = [ "__version__", diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index baf01deb..85cc8855 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -1,8 +1,10 @@ +import aiodataloader import sqlalchemy -from promise import dataloader, promise from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext +from .utils import is_sqlalchemy_version_less_than + def get_batch_resolver(relationship_prop): @@ -10,10 +12,10 @@ def get_batch_resolver(relationship_prop): # This is so SQL string generation is cached under-the-hood via `bakery` selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - class RelationshipLoader(dataloader.DataLoader): + class RelationshipLoader(aiodataloader.DataLoader): cache = False - def batch_load_fn(self, parents): # pylint: disable=method-hidden + async def batch_load_fn(self, parents): """ Batch loads the relationships of all the parents as one SQL statement. @@ -52,21 +54,36 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden states = [(sqlalchemy.inspect(parent), True) for parent in parents] # For our purposes, the query_context will only used to get the session - query_context = QueryContext(session.query(parent_mapper.entity)) - - selectin_loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper, - ) - - return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents]) + query_context = None + if is_sqlalchemy_version_less_than('1.4'): + query_context = QueryContext(session.query(parent_mapper.entity)) + else: + parent_mapper_query = session.query(parent_mapper.entity) + query_context = parent_mapper_query._compile_context() + + if is_sqlalchemy_version_less_than('1.4'): + selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper + ) + else: + selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + None + ) + + return [getattr(parent, relationship_prop.key) for parent in parents] loader = RelationshipLoader() - def resolve(root, info, **args): - return loader.load(root) + async def resolve(root, info, **args): + return await loader.load(root) return resolve diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index f4b805e2..5d75984b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,12 +1,16 @@ -from enum import EnumMeta +import datetime +import typing +import warnings +from decimal import Decimal +from functools import singledispatch +from typing import Any -from singledispatch import singledispatch from sqlalchemy import types from sqlalchemy.dialects import postgresql from sqlalchemy.orm import interfaces, strategies -from 
graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, - String) +from graphene import (ID, Boolean, Date, DateTime, Dynamic, Enum, Field, Float, + Int, List, String, Time) from graphene.types.json import JSONString from .batching import get_batch_resolver @@ -15,12 +19,24 @@ default_connection_field_factory) from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver +from .utils import (registry_sqlalchemy_model_from_str, safe_isinstance, + singledispatchbymatchfunction, value_equals) + +try: + from typing import ForwardRef +except ImportError: + # python 3.6 + from typing import _ForwardRef as ForwardRef try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType except ImportError: ChoiceType = JSONType = ScalarListType = TSVectorType = object +try: + from sqlalchemy_utils.types.choice import EnumTypeImpl +except ImportError: + EnumTypeImpl = object is_selectin_available = getattr(strategies, 'SelectInLoader', None) @@ -44,6 +60,7 @@ def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_fiel :param dict field_kwargs: :rtype: Dynamic """ + def dynamic_type(): """:rtype: Field|None""" direction = relationship_prop.direction @@ -110,9 +127,11 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): - if 'type' not in field_kwargs: - # TODO The default type should be dependent on the type of the property propety. - field_kwargs['type'] = String + if 'type_' not in field_kwargs: + field_kwargs['type_'] = convert_hybrid_property_return_type(hybrid_prop) + + if 'description' not in field_kwargs: + field_kwargs['description'] = getattr(hybrid_prop, "__doc__", None) return Field( resolver=resolver, @@ -156,7 +175,8 @@ def inner(fn): def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) + + field_kwargs.setdefault('type_', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) field_kwargs.setdefault('required', not is_column_nullable(column)) field_kwargs.setdefault('description', get_column_doc(column)) @@ -221,7 +241,7 @@ def convert_enum_to_enum(type, column, registry=None): @convert_sqlalchemy_type.register(ChoiceType) def convert_choice_to_enum(type, column, registry=None): name = "{}_{}".format(column.table.name, column.name).upper() - if isinstance(type.choices, EnumMeta): + if isinstance(type.type_impl, EnumTypeImpl): # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta # do not use from_enum here because we can have more than one enum column in table return Enum(name, list((v.name, v.value) for v in type.choices)) @@ -234,11 +254,15 @@ def convert_scalar_list_to_list(type, column, registry=None): return List(String) +def init_array_list_recursive(inner_type, n): + return inner_type if n == 0 else List(init_array_list_recursive(inner_type, n - 1)) + + @convert_sqlalchemy_type.register(types.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY) def convert_array_to_list(_type, column, registry=None): inner_type = convert_sqlalchemy_type(column.type.item_type, column) - return List(inner_type) + return List(init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1)) @convert_sqlalchemy_type.register(postgresql.HSTORE) @@ -251,3 +275,115 @@ def 
convert_json_to_string(type, column, registry=None): @convert_sqlalchemy_type.register(JSONType) def convert_json_type_to_string(type, column, registry=None): return JSONString + + +@singledispatchbymatchfunction +def convert_sqlalchemy_hybrid_property_type(arg: Any): + existing_graphql_type = get_global_registry().get_type_for_model(arg) + if existing_graphql_type: + return existing_graphql_type + + # No valid type found, warn and fall back to graphene.String + warnings.warn( + (f"I don't know how to generate a GraphQL type out of a \"{arg}\" type." + "Falling back to \"graphene.String\"") + ) + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(str)) +def convert_sqlalchemy_hybrid_property_type_str(arg): + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(int)) +def convert_sqlalchemy_hybrid_property_type_int(arg): + return Int + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(float)) +def convert_sqlalchemy_hybrid_property_type_float(arg): + return Float + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(Decimal)) +def convert_sqlalchemy_hybrid_property_type_decimal(arg): + # The reason Decimal should be serialized as a String is because this is a + # base10 type used in things like money, and string allows it to not + # lose precision (which would happen if we downcasted to a Float, for example) + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(bool)) +def convert_sqlalchemy_hybrid_property_type_bool(arg): + return Boolean + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.datetime)) +def convert_sqlalchemy_hybrid_property_type_datetime(arg): + return DateTime + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.date)) +def convert_sqlalchemy_hybrid_property_type_date(arg): + return Date + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.time)) +def convert_sqlalchemy_hybrid_property_type_time(arg): + return Time + + +@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) == typing.Union) +def convert_sqlalchemy_hybrid_property_type_option_t(arg): + # Option is actually Union[T, ] + + # Just get the T out of the list of arguments by filtering out the NoneType + internal_type = next(filter(lambda x: not type(None) == x, arg.__args__)) + + graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + + return graphql_internal_type + + +@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) +def convert_sqlalchemy_hybrid_property_type_list_t(arg): + # type is either list[T] or List[T], generic argument at __args__[0] + internal_type = arg.__args__[0] + + graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + + return List(graphql_internal_type) + + +@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(ForwardRef)) +def convert_sqlalchemy_hybrid_property_forwardref(arg): + """ + Generate a lambda that will resolve the type at runtime + This takes care of self-references + """ + + def forward_reference_solver(): + model = registry_sqlalchemy_model_from_str(arg.__forward_arg__) + if not model: + return String + # Always fall back to string if no ForwardRef type found. 
+ return get_global_registry().get_type_for_model(model) + + return forward_reference_solver + + +@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(str)) +def convert_sqlalchemy_hybrid_property_bare_str(arg): + """ + Convert Bare String into a ForwardRef + """ + + return convert_sqlalchemy_hybrid_property_type(ForwardRef(arg)) + + +def convert_hybrid_property_return_type(hybrid_prop): + # Grab the original method's return type annotations from inside the hybrid property + return_type_annotation = hybrid_prop.fget.__annotations__.get('return', str) + + return convert_sqlalchemy_hybrid_property_type(return_type_annotation) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index 35bb51fe..a2ed17ad 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -1,4 +1,3 @@ -import six from sqlalchemy.orm import ColumnProperty from sqlalchemy.types import Enum as SQLAlchemyEnumType @@ -63,7 +62,7 @@ def enum_for_field(obj_type, field_name): if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) - if not field_name or not isinstance(field_name, six.string_types): + if not field_name or not isinstance(field_name, str): raise TypeError( "Expected a field name, but got: {!r}".format(field_name)) registry = obj_type._meta.registry diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 780fcbf0..d7a83392 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,17 +1,17 @@ +import enum import warnings from functools import partial -import six from promise import Promise, is_thenable from sqlalchemy.orm.query import Query from graphene import NonNull from graphene.relay import Connection, ConnectionField -from graphene.relay.connection import PageInfo -from graphql_relay.connection.arrayconnection import connection_from_list_slice +from graphene.relay.connection import connection_adapter, page_info_adapter +from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import get_query +from .utils import EnumValue, get_query class UnsortedSQLAlchemyConnectionField(ConnectionField): @@ -19,10 +19,10 @@ class UnsortedSQLAlchemyConnectionField(ConnectionField): def type(self): from .types import SQLAlchemyObjectType - _type = super(ConnectionField, self).type - nullable_type = get_nullable_type(_type) + type_ = super(ConnectionField, self).type + nullable_type = get_nullable_type(type_) if issubclass(nullable_type, Connection): - return _type + return type_ assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" ).format(nullable_type.__name__) @@ -31,7 +31,7 @@ def type(self): ), "The type {} doesn't have a connection".format( nullable_type.__name__ ) - assert _type == nullable_type, ( + assert type_ == nullable_type, ( "Passing a SQLAlchemyObjectType instance is deprecated. 
" "Pass the connection type instead accessible via SQLAlchemyObjectType.connection" ) @@ -53,15 +53,19 @@ def resolve_connection(cls, connection_type, model, info, args, resolved): _len = resolved.count() else: _len = len(resolved) - connection = connection_from_list_slice( - resolved, - args, + + def adjusted_connection_adapter(edges, pageInfo): + return connection_adapter(connection_type, edges, pageInfo) + + connection = connection_from_array_slice( + array_slice=resolved, + args=args, slice_start=0, - list_length=_len, - list_slice_length=_len, - connection_type=connection_type, - pageinfo_type=PageInfo, + array_length=_len, + array_slice_length=_len, + connection_type=adjusted_connection_adapter, edge_type=connection_type.Edge, + page_info_type=page_info_adapter, ) connection.iterable = resolved connection.length = _len @@ -77,7 +81,7 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg return on_resolve(resolved) - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): return partial( self.connection_resolver, parent_resolver, @@ -88,8 +92,8 @@ def get_resolver(self, parent_resolver): # TODO Rename this to SortableSQLAlchemyConnectionField class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): - def __init__(self, type, *args, **kwargs): - nullable_type = get_nullable_type(type) + def __init__(self, type_, *args, **kwargs): + nullable_type = get_nullable_type(type_) if "sort" not in kwargs and issubclass(nullable_type, Connection): # Let super class raise if type is not a Connection try: @@ -103,16 +107,25 @@ def __init__(self, type, *args, **kwargs): ) elif "sort" in kwargs and kwargs["sort"] is None: del kwargs["sort"] - super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) @classmethod def get_query(cls, model, info, sort=None, **args): query = get_query(model, info.context) if sort is not None: - if isinstance(sort, six.string_types): - query = query.order_by(sort.value) - else: - query = query.order_by(*(col.value for col in sort)) + if not isinstance(sort, list): + sort = [sort] + sort_args = [] + # ensure consistent handling of graphene Enums, enum values and + # plain strings + for item in sort: + if isinstance(item, enum.Enum): + sort_args.append(item.value.value) + elif isinstance(item, EnumValue): + sort_args.append(item.value) + else: + sort_args.append(item) + query = query.order_by(*sort_args) return query @@ -123,7 +136,7 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): Use at your own risk. """ - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): return partial( self.connection_resolver, self.resolver, @@ -148,13 +161,13 @@ def default_connection_field_factory(relationship, registry, **field_kwargs): __connectionFactory = UnsortedSQLAlchemyConnectionField -def createConnectionField(_type, **field_kwargs): +def createConnectionField(type_, **field_kwargs): warnings.warn( 'createConnectionField is deprecated and will be removed in the next ' 'major version. 
Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', DeprecationWarning, ) - return __connectionFactory(_type, **field_kwargs) + return __connectionFactory(type_, **field_kwargs) def registerConnectionFieldFactory(factoryMethod): diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index c20bc2ca..acfa744b 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,6 +1,5 @@ from collections import defaultdict -import six from sqlalchemy.types import Enum as SQLAlchemyEnumType from graphene import Enum @@ -43,7 +42,7 @@ def register_orm_field(self, obj_type, field_name, orm_field): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) ) - if not field_name or not isinstance(field_name, six.string_types): + if not field_name or not isinstance(field_name, str): raise TypeError("Expected a field name, but got: {!r}".format(field_name)) self._registry_orm_fields[obj_type][field_name] = orm_field diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 98515051..34ba9d8a 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -22,7 +22,7 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.yield_fixture(scope="function") +@pytest.fixture(scope="function") def session_factory(): engine = create_engine(test_db_url) Base.metadata.create_all(engine) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 88e992b9..e41adb51 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -1,6 +1,9 @@ from __future__ import absolute_import +import datetime import enum +from decimal import Decimal +from typing import List, Optional, Tuple from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, func, select) @@ -65,10 +68,35 @@ class Reporter(Base): articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) + @hybrid_property + def hybrid_prop_with_doc(self): + """Docstring test""" + return self.first_name + @hybrid_property def hybrid_prop(self): return self.first_name + @hybrid_property + def hybrid_prop_str(self) -> str: + return self.first_name + + @hybrid_property + def hybrid_prop_int(self) -> int: + return 42 + + @hybrid_property + def hybrid_prop_float(self) -> float: + return 42.3 + + @hybrid_property + def hybrid_prop_bool(self) -> bool: + return True + + @hybrid_property + def hybrid_prop_list(self) -> List[int]: + return [1, 2, 3] + column_prop = column_property( select([func.cast(func.count(id), Integer)]), doc="Column property" ) @@ -95,3 +123,108 @@ def __subclasses__(cls): editor_table = Table("editors", Base.metadata, autoload=True) mapper(ReflectedEditor, editor_table) + + +############################################ +# The models below are mainly used in the +# @hybrid_property type inference scenarios +############################################ + + +class ShoppingCartItem(Base): + __tablename__ = "shopping_cart_items" + + id = Column(Integer(), primary_key=True) + + @hybrid_property + def hybrid_prop_shopping_cart(self) -> List['ShoppingCart']: + return [ShoppingCart(id=1)] + + +class ShoppingCart(Base): + __tablename__ = "shopping_carts" + + id = Column(Integer(), primary_key=True) + + # Standard Library types + + @hybrid_property + def hybrid_prop_str(self) -> str: + return self.first_name + + @hybrid_property + def 
hybrid_prop_int(self) -> int: + return 42 + + @hybrid_property + def hybrid_prop_float(self) -> float: + return 42.3 + + @hybrid_property + def hybrid_prop_bool(self) -> bool: + return True + + @hybrid_property + def hybrid_prop_decimal(self) -> Decimal: + return Decimal("3.14") + + @hybrid_property + def hybrid_prop_date(self) -> datetime.date: + return datetime.datetime.now().date() + + @hybrid_property + def hybrid_prop_time(self) -> datetime.time: + return datetime.datetime.now().time() + + @hybrid_property + def hybrid_prop_datetime(self) -> datetime.datetime: + return datetime.datetime.now() + + # Lists and Nested Lists + + @hybrid_property + def hybrid_prop_list_int(self) -> List[int]: + return [1, 2, 3] + + @hybrid_property + def hybrid_prop_list_date(self) -> List[datetime.date]: + return [self.hybrid_prop_date, self.hybrid_prop_date, self.hybrid_prop_date] + + @hybrid_property + def hybrid_prop_nested_list_int(self) -> List[List[int]]: + return [self.hybrid_prop_list_int, ] + + @hybrid_property + def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: + return [[self.hybrid_prop_list_int, ], ] + + # Other SQLAlchemy Instances + @hybrid_property + def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: + return ShoppingCartItem(id=1) + + # Other SQLAlchemy Instances + @hybrid_property + def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: + return [ShoppingCartItem(id=1), ShoppingCartItem(id=2)] + + # Unsupported Type + @hybrid_property + def hybrid_prop_unsupported_type_tuple(self) -> Tuple[str, str]: + return "this will actually", "be a string" + + # Self-references + + @hybrid_property + def hybrid_prop_self_referential(self) -> 'ShoppingCart': + return ShoppingCart(id=1) + + @hybrid_property + def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: + return [ShoppingCart(id=1)] + + # Optional[T] + + @hybrid_property + def hybrid_prop_optional_self_referential(self) -> Optional['ShoppingCart']: + return None diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index fc646a3c..1896900b 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -1,3 +1,4 @@ +import ast import contextlib import logging @@ -9,8 +10,9 @@ from ..fields import (BatchSQLAlchemyConnectionField, default_connection_field_factory) from ..types import ORMField, SQLAlchemyObjectType +from ..utils import is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reporter -from .utils import is_sqlalchemy_version_less_than, to_std_dicts +from .utils import remove_cache_miss_stat, to_std_dicts class MockLoggingHandler(logging.Handler): @@ -75,7 +77,8 @@ def resolve_reporters(self, info): pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) -def test_many_to_one(session_factory): +@pytest.mark.asyncio +async def test_many_to_one(session_factory): session = session_factory() reporter_1 = Reporter( @@ -103,7 +106,7 @@ def test_many_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = await schema.execute_async(""" query { articles { headline @@ -125,26 +128,12 @@ def test_many_to_one(session_factory): assert len(sql_statements) == 1 return - assert messages == [ - 'BEGIN (implicit)', - - 'SELECT articles.id AS articles_id, ' - 
'articles.headline AS articles_headline, ' - 'articles.pub_date AS articles_pub_date, ' - 'articles.reporter_id AS articles_reporter_id \n' - 'FROM articles', - '()', - - 'SELECT reporters.id AS reporters_id, ' - '(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' - 'reporters.first_name AS reporters_first_name, ' - 'reporters.last_name AS reporters_last_name, ' - 'reporters.email AS reporters_email, ' - 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' - 'FROM reporters \n' - 'WHERE reporters.id IN (?, ?)', - '(1, 2)', - ] + if not is_sqlalchemy_version_less_than('1.4'): + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] assert not result.errors result = to_std_dicts(result.data) @@ -166,7 +155,8 @@ def test_many_to_one(session_factory): } -def test_one_to_one(session_factory): +@pytest.mark.asyncio +async def test_one_to_one(session_factory): session = session_factory() reporter_1 = Reporter( @@ -194,7 +184,7 @@ def test_one_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = await schema.execute_async(""" query { reporters { firstName @@ -216,26 +206,12 @@ def test_one_to_one(session_factory): assert len(sql_statements) == 1 return - assert messages == [ - 'BEGIN (implicit)', - - 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' - 'reporters.id AS reporters_id, ' - 'reporters.first_name AS reporters_first_name, ' - 'reporters.last_name AS reporters_last_name, ' - 'reporters.email AS reporters_email, ' - 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' - 'FROM reporters', - '()', - - 'SELECT articles.reporter_id AS articles_reporter_id, ' - 'articles.id AS articles_id, ' - 'articles.headline AS articles_headline, ' - 'articles.pub_date AS articles_pub_date \n' - 'FROM articles \n' - 'WHERE articles.reporter_id IN (?, ?)', - '(1, 2)' - ] + if not is_sqlalchemy_version_less_than('1.4'): + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] assert not result.errors result = to_std_dicts(result.data) @@ -257,7 +233,8 @@ def test_one_to_one(session_factory): } -def test_one_to_many(session_factory): +@pytest.mark.asyncio +async def test_one_to_many(session_factory): session = session_factory() reporter_1 = Reporter( @@ -293,7 +270,7 @@ def test_one_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = await schema.execute_async(""" query { reporters { firstName @@ -319,26 +296,12 @@ def test_one_to_many(session_factory): assert len(sql_statements) == 1 return - assert messages == [ - 'BEGIN (implicit)', - - 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' - 'reporters.id AS reporters_id, ' - 'reporters.first_name AS reporters_first_name, ' - 'reporters.last_name AS reporters_last_name, ' - 'reporters.email AS reporters_email, ' - 'reporters.favorite_pet_kind AS 
reporters_favorite_pet_kind \n' - 'FROM reporters', - '()', - - 'SELECT articles.reporter_id AS articles_reporter_id, ' - 'articles.id AS articles_id, ' - 'articles.headline AS articles_headline, ' - 'articles.pub_date AS articles_pub_date \n' - 'FROM articles \n' - 'WHERE articles.reporter_id IN (?, ?)', - '(1, 2)' - ] + if not is_sqlalchemy_version_less_than('1.4'): + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] assert not result.errors result = to_std_dicts(result.data) @@ -382,7 +345,8 @@ def test_one_to_many(session_factory): } -def test_many_to_many(session_factory): +@pytest.mark.asyncio +async def test_many_to_many(session_factory): session = session_factory() reporter_1 = Reporter( @@ -420,7 +384,7 @@ def test_many_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = await schema.execute_async(""" query { reporters { firstName @@ -446,31 +410,12 @@ def test_many_to_many(session_factory): assert len(sql_statements) == 1 return - assert messages == [ - 'BEGIN (implicit)', - - 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' - 'reporters.id AS reporters_id, ' - 'reporters.first_name AS reporters_first_name, ' - 'reporters.last_name AS reporters_last_name, ' - 'reporters.email AS reporters_email, ' - 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' - 'FROM reporters', - '()', - - 'SELECT reporters_1.id AS reporters_1_id, ' - 'pets.id AS pets_id, ' - 'pets.name AS pets_name, ' - 'pets.pet_kind AS pets_pet_kind, ' - 'pets.hair_kind AS pets_hair_kind, ' - 'pets.reporter_id AS pets_reporter_id \n' - 'FROM reporters AS reporters_1 ' - 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id ' - 'JOIN pets ON pets.id = association_1.pet_id \n' - 'WHERE reporters_1.id IN (?, ?) 
' - 'ORDER BY pets.id', - '(1, 2)' - ] + if not is_sqlalchemy_version_less_than('1.4'): + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] assert not result.errors result = to_std_dicts(result.data) @@ -586,7 +531,8 @@ def resolve_reporters(self, info): assert len(select_statements) == 2 -def test_connection_factory_field_overrides_batching_is_false(session_factory): +@pytest.mark.asyncio +async def test_connection_factory_field_overrides_batching_is_false(session_factory): session = session_factory() reporter_1 = Reporter(first_name='Reporter_1') session.add(reporter_1) @@ -620,7 +566,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + await schema.execute_async(""" query { reporters { articles { diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 1e5ee4f1..11e9d0e0 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,13 +1,11 @@ import pytest -from graphql.backend import GraphQLCachedBackend, GraphQLCoreBackend import graphene from graphene import relay -from ..fields import BatchSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType +from ..utils import is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reporter -from .utils import is_sqlalchemy_version_less_than if is_sqlalchemy_version_less_than('1.2'): pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) @@ -18,19 +16,16 @@ class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class ArticleType(SQLAlchemyObjectType): class Meta: model = Article interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class PetType(SQLAlchemyObjectType): class Meta: model = Pet interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class Query(graphene.ObjectType): articles = graphene.Field(graphene.List(ArticleType)) @@ -47,15 +42,12 @@ def resolve_reporters(self, info): def benchmark_query(session_factory, benchmark, query): schema = get_schema() - cached_backend = GraphQLCachedBackend(GraphQLCoreBackend()) - cached_backend.document_from_string(schema, query) # Prime cache @benchmark def execute_query(): result = schema.execute( query, context_value={"session": session_factory()}, - backend=cached_backend, ) assert not result.errors diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index f0fc1802..70e11713 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,4 +1,5 @@ import enum +from typing import Dict, Union import pytest from sqlalchemy import Column, func, select, types @@ -9,9 +10,11 @@ from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType import graphene +from graphene import Boolean, Float, Int, Scalar, String from graphene.relay import Node -from graphene.types.datetime import DateTime +from graphene.types.datetime import Date, DateTime, Time from graphene.types.json import JSONString 
+from graphene.types.structures import List, Structure from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -20,7 +23,8 @@ default_connection_field_factory) from ..registry import Registry, get_global_registry from ..types import SQLAlchemyObjectType -from .models import Article, CompositeFullName, Pet, Reporter +from .models import (Article, CompositeFullName, Pet, Reporter, ShoppingCart, + ShoppingCartItem) def mock_resolver(): @@ -51,7 +55,7 @@ def test_should_unknown_sqlalchemy_field_raise_exception(): re_err = "Don't know how to convert the SQLAlchemy field" with pytest.raises(Exception, match=re_err): # support legacy Binary type and subsequent LargeBinary - get_field(getattr(types, 'LargeBinary', types.Binary)()) + get_field(getattr(types, 'LargeBinary', types.BINARY)()) def test_should_date_convert_string(): @@ -324,6 +328,21 @@ def test_should_array_convert(): assert field.type.of_type == graphene.Int +def test_should_2d_array_convert(): + field = get_field(types.ARRAY(types.Integer, dimensions=2)) + assert isinstance(field.type, graphene.List) + assert isinstance(field.type.of_type, graphene.List) + assert field.type.of_type.of_type == graphene.Int + + +def test_should_3d_array_convert(): + field = get_field(types.ARRAY(types.Integer, dimensions=3)) + assert isinstance(field.type, graphene.List) + assert isinstance(field.type.of_type, graphene.List) + assert isinstance(field.type.of_type.of_type, graphene.List) + assert field.type.of_type.of_type.of_type == graphene.Int + + def test_should_postgresql_json_convert(): assert get_field(postgresql.JSON()).type == graphene.JSONString @@ -369,3 +388,88 @@ def __init__(self, col1, col2): Registry(), mock_resolver, ) + + +def test_sqlalchemy_hybrid_property_type_inference(): + class ShoppingCartItemType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + interfaces = (Node,) + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCart + interfaces = (Node,) + + ####################################################### + # Check ShoppingCartItem's Properties and Return Types + ####################################################### + + shopping_cart_item_expected_types: Dict[str, Union[Scalar, Structure]] = { + 'hybrid_prop_shopping_cart': List(ShoppingCartType) + } + + assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted([ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_item_expected_types.keys() + ]) + + for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_item_expected_types.items(): + hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] + + # this is a simple way of showing the failed property name + # instead of having to unroll the loop. 
+ assert ( + (hybrid_prop_name, str(hybrid_prop_field.type)) == + (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + ) + assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property + + ################################################### + # Check ShoppingCart's Properties and Return Types + ################################################### + + shopping_cart_expected_types: Dict[str, Union[Scalar, Structure]] = { + # Basic types + "hybrid_prop_str": String, + "hybrid_prop_int": Int, + "hybrid_prop_float": Float, + "hybrid_prop_bool": Boolean, + "hybrid_prop_decimal": String, # Decimals should be serialized Strings + "hybrid_prop_date": Date, + "hybrid_prop_time": Time, + "hybrid_prop_datetime": DateTime, + # Lists and Nested Lists + "hybrid_prop_list_int": List(Int), + "hybrid_prop_list_date": List(Date), + "hybrid_prop_nested_list_int": List(List(Int)), + "hybrid_prop_deeply_nested_list_int": List(List(List(Int))), + "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, + "hybrid_prop_shopping_cart_item_list": List(ShoppingCartItemType), + "hybrid_prop_unsupported_type_tuple": String, + # Self Referential List + "hybrid_prop_self_referential": ShoppingCartType, + "hybrid_prop_self_referential_list": List(ShoppingCartType), + # Optionals + "hybrid_prop_optional_self_referential": ShoppingCartType, + } + + assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted([ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_expected_types.keys() + ]) + + for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_expected_types.items(): + hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] + + # this is a simple way of showing the failed property name + # instead of having to unroll the loop. 
+ assert ( + (hybrid_prop_name, str(hybrid_prop_field.type)) == + (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + ) + assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index ec585d57..5166c45f 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -32,7 +32,7 @@ def resolve_reporters(self, _info): def resolve_pets(self, _info, kind): query = session.query(Pet) if kind: - query = query.filter_by(pet_kind=kind) + query = query.filter_by(pet_kind=kind.value) return query query = """ @@ -131,7 +131,7 @@ class Query(graphene.ObjectType): def resolve_pet(self, info, kind=None): query = session.query(Pet) if kind: - query = query.filter(Pet.pet_kind == kind) + query = query.filter(Pet.pet_kind == kind.value) return query.first() query = """ diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index d6f6965d..6291d4f8 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -354,7 +354,7 @@ def makeNodes(nodeList): """ result = schema.execute(queryError, context_value={"session": session}) assert result.errors is not None - assert '"sort" has invalid value' in result.errors[0].message + assert 'cannot represent non-enum value' in result.errors[0].message queryNoSort = """ query sortTest { diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index bf563b6e..9a2e992d 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,11 +1,14 @@ -import mock +from unittest import mock + import pytest -import six # noqa F401 +import sqlalchemy.exc +import sqlalchemy.orm.exc -from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, - ObjectType, Schema, String) +from graphene import (Boolean, Dynamic, Field, Float, GlobalID, Int, List, + Node, NonNull, ObjectType, Schema, String) from graphene.relay import Connection +from .. 
import utils from ..converter import convert_sqlalchemy_composite from ..fields import (SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField, createConnectionField, @@ -71,7 +74,7 @@ class Meta: model = Article interfaces = (Node,) - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ # Columns "column_prop", # SQLAlchemy retuns column properties first "id", @@ -82,12 +85,18 @@ class Meta: # Composite "composite_prop", # Hybrid + "hybrid_prop_with_doc", "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", # Relationship "pets", "articles", "favorite_article", - ] + ]) # column first_name_field = ReporterType._meta.fields['first_name'] @@ -112,6 +121,42 @@ class Meta: # "doc" is ignored by hybrid_property assert hybrid_prop.description is None + # hybrid_property_str + hybrid_prop_str = ReporterType._meta.fields['hybrid_prop_str'] + assert hybrid_prop_str.type == String + # "doc" is ignored by hybrid_property + assert hybrid_prop_str.description is None + + # hybrid_property_int + hybrid_prop_int = ReporterType._meta.fields['hybrid_prop_int'] + assert hybrid_prop_int.type == Int + # "doc" is ignored by hybrid_property + assert hybrid_prop_int.description is None + + # hybrid_property_float + hybrid_prop_float = ReporterType._meta.fields['hybrid_prop_float'] + assert hybrid_prop_float.type == Float + # "doc" is ignored by hybrid_property + assert hybrid_prop_float.description is None + + # hybrid_property_bool + hybrid_prop_bool = ReporterType._meta.fields['hybrid_prop_bool'] + assert hybrid_prop_bool.type == Boolean + # "doc" is ignored by hybrid_property + assert hybrid_prop_bool.description is None + + # hybrid_property_list + hybrid_prop_list = ReporterType._meta.fields['hybrid_prop_list'] + assert hybrid_prop_list.type == List(Int) + # "doc" is ignored by hybrid_property + assert hybrid_prop_list.description is None + + # hybrid_prop_with_doc + hybrid_prop_with_doc = ReporterType._meta.fields['hybrid_prop_with_doc'] + assert hybrid_prop_with_doc.type == String + # docstring is picked up from hybrid_prop_with_doc + assert hybrid_prop_with_doc.description == "Docstring test" + # relationship favorite_article_field = ReporterType._meta.fields['favorite_article'] assert isinstance(favorite_article_field, Dynamic) @@ -136,15 +181,16 @@ class Meta: # columns email = ORMField(deprecation_reason='Overridden') - email_v2 = ORMField(model_attr='email', type=Int) + email_v2 = ORMField(model_attr='email', type_=Int) # column_property - column_prop = ORMField(type=String) + column_prop = ORMField(type_=String) # composite composite_prop = ORMField() # hybrid_property + hybrid_prop_with_doc = ORMField(description='Overridden') hybrid_prop = ORMField(description='Overridden') # relationships @@ -163,7 +209,7 @@ class Meta: interfaces = (Node,) use_connection = False - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ # Fields from ReporterMixin "first_name", "last_name", @@ -172,6 +218,7 @@ class Meta: "email_v2", "column_prop", "composite_prop", + "hybrid_prop_with_doc", "hybrid_prop", "favorite_article", "articles", @@ -179,7 +226,12 @@ class Meta: # Then the automatic SQLAlchemy fields "id", "favorite_pet_kind", - ] + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + ]) first_name_field = ReporterType._meta.fields['first_name'] 
assert isinstance(first_name_field.type, NonNull) @@ -207,6 +259,11 @@ class Meta: assert hybrid_prop_field.description == "Overridden" assert hybrid_prop_field.deprecation_reason is None + hybrid_prop_with_doc_field = ReporterType._meta.fields['hybrid_prop_with_doc'] + assert hybrid_prop_with_doc_field.type == String + assert hybrid_prop_with_doc_field.description == "Overridden" + assert hybrid_prop_with_doc_field.deprecation_reason is None + column_prop_field_v2 = ReporterType._meta.fields['column_prop'] assert column_prop_field_v2.type == String assert column_prop_field_v2.description is None @@ -268,18 +325,24 @@ class Meta: first_name = ORMField() # Takes precedence last_name = ORMField() # Noop - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ "first_name", "last_name", "column_prop", "email", "favorite_pet_kind", "composite_prop", + "hybrid_prop_with_doc", "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", "pets", "articles", "favorite_article", - ] + ]) def test_only_and_exclude_fields(): @@ -384,7 +447,7 @@ class Meta: assert issubclass(CustomReporterType, ObjectType) assert CustomReporterType._meta.model == Reporter - assert len(CustomReporterType._meta.fields) == 11 + assert len(CustomReporterType._meta.fields) == 17 # Test Custom SQLAlchemyObjectType with Custom Options @@ -492,3 +555,65 @@ class Meta: def test_deprecated_createConnectionField(): with pytest.warns(DeprecationWarning): createConnectionField(None) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_unique_errors_propagate(class_mapper_mock): + # Define unique error to detect + class UniqueError(Exception): + pass + + # Mock class_mapper effect + class_mapper_mock.side_effect = UniqueError + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleOne(SQLAlchemyObjectType): + class Meta(object): + model = Article + except UniqueError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, UniqueError) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_argument_errors_propagate(class_mapper_mock): + # Mock class_mapper effect + class_mapper_mock.side_effect = sqlalchemy.exc.ArgumentError + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleTwo(SQLAlchemyObjectType): + class Meta(object): + model = Article + except sqlalchemy.exc.ArgumentError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, sqlalchemy.exc.ArgumentError) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_unmapped_errors_reformat(class_mapper_mock): + # Mock class_mapper effect + class_mapper_mock.side_effect = sqlalchemy.orm.exc.UnmappedClassError(object) + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleThree(SQLAlchemyObjectType): + class Meta(object): + model = Article + except ValueError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, ValueError) + assert "You need to pass a valid SQLAlchemy Model" in str(error) diff --git a/graphene_sqlalchemy/tests/utils.py 
b/graphene_sqlalchemy/tests/utils.py index 428757c3..c90ee476 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,4 +1,4 @@ -import pkg_resources +import re def to_std_dicts(value): @@ -11,6 +11,7 @@ def to_std_dicts(value): return value -def is_sqlalchemy_version_less_than(version_string): - """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) +def remove_cache_miss_stat(message): + """Remove the stat from the echoed query message when the cache is missed for sqlalchemy version >= 1.4""" + # https://github.com/sqlalchemy/sqlalchemy/blob/990eb3d8813369d3b8a7776ae85fb33627443d30/lib/sqlalchemy/engine/default.py#L1177 + return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index ff22cded..ac69b697 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -27,7 +27,7 @@ class ORMField(OrderedType): def __init__( self, model_attr=None, - type=None, + type_=None, required=None, description=None, deprecation_reason=None, @@ -49,7 +49,7 @@ class MyType(SQLAlchemyObjectType): class Meta: model = MyModel - id = ORMField(type=graphene.Int) + id = ORMField(type_=graphene.Int) name = ORMField(required=True) -> MyType.id will be of type Int (vs ID). @@ -58,7 +58,7 @@ class Meta: :param str model_attr: Name of the SQLAlchemy model attribute used to resolve this field. Default to the name of the attribute referencing the ORMField. - :param type: + :param type_: Default to the type mapping in converter.py. :param str description: Default to the `doc` attribute of the SQLAlchemy column property. @@ -77,7 +77,7 @@ class Meta: # The is only useful for documentation and auto-completion common_kwargs = { 'model_attr': model_attr, - 'type': type, + 'type_': type_, 'required': required, 'description': description, 'deprecation_reason': deprecation_reason, @@ -207,9 +207,11 @@ def __init_subclass_with_meta__( _meta=None, **options ): - assert is_mapped_class(model), ( - "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".' 
- ).format(cls.__name__, model) + # Make sure model is a valid SQLAlchemy model + if not is_mapped_class(model): + raise ValueError( + "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".'.format(cls.__name__, model) + ) if not registry: registry = get_global_registry() diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 7139eefc..301e782c 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,10 +1,15 @@ import re import warnings +from collections import OrderedDict +from typing import Any, Callable, Dict, Optional +import pkg_resources from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene_sqlalchemy.registry import get_global_registry + def get_session(context): return context.get("session") @@ -26,7 +31,13 @@ def get_query(model, context): def is_mapped_class(cls): try: class_mapper(cls) - except (ArgumentError, UnmappedClassError): + except ArgumentError as error: + # Only handle ArgumentErrors for non-class objects + if "Class object expected" in str(error): + return False + raise + except UnmappedClassError: + # Unmapped classes return false return False else: return True @@ -80,7 +91,6 @@ def _deprecated_default_symbol_name(column_name, sort_asc): def _deprecated_object_type_for_model(cls, name): - try: return _deprecated_object_type_cache[cls, name] except KeyError: @@ -140,3 +150,59 @@ def sort_argument_for_model(cls, has_default=True): enum.default = None return Argument(List(enum), default_value=enum.default) + + +def is_sqlalchemy_version_less_than(version_string): + """Check the installed SQLAlchemy version""" + return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) + + +class singledispatchbymatchfunction: + """ + Inspired by @singledispatch, this is a variant that works using a matcher function + instead of relying on the type of the first argument. + The register method can be used to register a new matcher, which is passed as the first argument: + """ + + def __init__(self, default: Callable): + self.registry: Dict[Callable, Callable] = OrderedDict() + self.default = default + + def __call__(self, *args, **kwargs): + for matcher_function, final_method in self.registry.items(): + # Register order is important. First one that matches, runs. + if matcher_function(args[0]): + return final_method(*args, **kwargs) + + # No match, using default. 
+ return self.default(*args, **kwargs) + + def register(self, matcher_function: Callable[[Any], bool]): + + def grab_function_from_outside(f): + self.registry[matcher_function] = f + return self + + return grab_function_from_outside + + +def value_equals(value): + """A simple function that makes the equality based matcher functions for + SingleDispatchByMatchFunction prettier""" + return lambda x: x == value + + +def safe_isinstance(cls): + def safe_isinstance_checker(arg): + try: + return isinstance(arg, cls) + except TypeError: + pass + return safe_isinstance_checker + + +def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: + try: + return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys()))) + except StopIteration: + pass diff --git a/setup.cfg b/setup.cfg index 4e8e5029..f36334d8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ max-line-length = 120 no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=app,database,flask,graphql,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=aiodataloader,app,database,flask,models,nameko,pkg_resources,promise,pytest,schema,setuptools,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER skip_glob=examples/nameko_sqlalchemy diff --git a/setup.py b/setup.py index 7b350c39..da49f1d4 100644 --- a/setup.py +++ b/setup.py @@ -13,24 +13,18 @@ requirements = [ # To keep things simple, we only support newer versions of Graphene - "graphene>=2.1.3,<3", + "graphene>=3.0.0b7", "promise>=2.3", - # Tests fail with 1.0.19 - "SQLAlchemy>=1.2,<2", - "six>=1.10.0,<2", - "singledispatch>=3.4.0.3,<4", + "SQLAlchemy>=1.1,<2", + "aiodataloader>=0.2.0,<1.0", ] -try: - import enum -except ImportError: # Python < 2.7 and Python 3.3 - requirements.append("enum34 >= 1.1.6") tests_require = [ - "pytest==4.3.1", - "mock==2.0.0", - "pytest-cov==2.6.1", - "sqlalchemy_utils==0.33.9", - "pytest-benchmark==3.2.1", + "pytest>=6.2.0,<7.0", + "pytest-asyncio>=0.15.1", + "pytest-cov>=2.11.0,<3.0", + "sqlalchemy_utils>=0.37.0,<1.0", + "pytest-benchmark>=3.4.0,<4.0", ] setup( @@ -46,12 +40,10 @@ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Programming Language :: Python :: Implementation :: PyPy", ], keywords="api graphql protocol rest relay graphene", @@ -60,8 +52,8 @@ extras_require={ "dev": [ "tox==3.7.0", # Should be kept in sync with tox.ini - "coveralls==1.10.0", "pre-commit==1.14.4", + "flake8==3.7.9", ], "test": tests_require, }, diff --git a/tox.ini b/tox.ini index 562da2dc..b8ce0618 100644 --- a/tox.ini +++ b/tox.ini @@ -1,20 +1,40 @@ [tox] -envlist = pre-commit,py{27,35,36,37}-sql{11,12,13} +envlist = pre-commit,py{36,37,38,39}-sql{12,13,14} skipsdist = true minversion = 3.7.0 +[gh-actions] +python = + 3.6: py36 + 3.7: py37 + 3.8: py38 + 3.9: py39 + +[gh-actions:env] +SQLALCHEMY = + 1.2: sql12 + 1.3: sql13 + 1.4: sql14 + [testenv] +passenv = GITHUB_* deps = .[test] - sql11: sqlalchemy>=1.1,<1.2 sql12: 
sqlalchemy>=1.2,<1.3 sql13: sqlalchemy>=1.3,<1.4 + sql14: sqlalchemy>=1.4,<1.5 commands = - pytest graphene_sqlalchemy --cov=graphene_sqlalchemy {posargs} + pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} [testenv:pre-commit] -basepython=python3.7 +basepython=python3.9 deps = .[dev] commands = pre-commit {posargs:run --all-files} + +[testenv:flake8] +basepython = python3.9 +deps = -e.[dev] +commands = + flake8 --exclude setup.py,docs,examples,tests,.tox --max-line-length 120
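
The converter changes above infer Graphene field types from the Python return annotations on `@hybrid_property` getters (previously every hybrid property fell back to `String`). Below is a minimal sketch of that behaviour; the `Customer` model and its fields are made up for illustration and are not part of this patch, and it assumes graphene>=3.0.0b7 with this branch of graphene-sqlalchemy installed:

```python
from typing import List

import graphene
from sqlalchemy import Column, Integer, String as SAString
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property

from graphene_sqlalchemy import SQLAlchemyObjectType

Base = declarative_base()


class Customer(Base):  # hypothetical model, for illustration only
    __tablename__ = "customers"

    id = Column(Integer, primary_key=True)
    name = Column(SAString)

    @hybrid_property
    def name_length(self) -> int:  # annotation maps to graphene.Int
        return len(self.name)

    @hybrid_property
    def name_words(self) -> List[str]:  # annotation maps to graphene.List(graphene.String)
        return self.name.split()


class CustomerType(SQLAlchemyObjectType):
    class Meta:
        model = Customer


# With this patch, the hybrid property fields are typed from the annotations
# instead of defaulting to String:
assert CustomerType._meta.fields["name_length"].type == graphene.Int
assert CustomerType._meta.fields["name_words"].type == graphene.List(graphene.String)
```

Un-annotated hybrid properties still default to `String`, and annotations the converter does not recognise emit a warning before falling back to `String`, as implemented in `convert_sqlalchemy_hybrid_property_type` above.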