Skip to content

Commit

Permalink
Merge branch 'master' into enable-sorting-for-batching
Browse files Browse the repository at this point in the history
  • Loading branch information
sabard committed Sep 9, 2022
2 parents d6746e3 + bb7af4b commit 027ea0e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .types import SQLAlchemyObjectType
from .utils import get_query, get_session

__version__ = "3.0.0b2"
__version__ = "3.0.0b3"

__all__ = [
"__version__",
Expand Down
19 changes: 17 additions & 2 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""The dataloader uses "select in loading" strategy to load related entities."""
from asyncio import get_event_loop
from typing import Dict
from typing import Any, Dict

import aiodataloader
import sqlalchemy
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext

from .utils import is_sqlalchemy_version_less_than
from .utils import (is_graphene_version_less_than,
is_sqlalchemy_version_less_than)


class RelationshipLoader(aiodataloader.DataLoader):
Expand Down Expand Up @@ -94,6 +95,20 @@ async def batch_load_fn(self, parents):
] = {}


def get_data_loader_impl() -> Any: # pragma: no cover
"""Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility,
aiodataloader is used in conjunction with older versions of graphene"""
if is_graphene_version_less_than("3.1.1"):
from aiodataloader import DataLoader
else:
from graphene.utils.dataloader import DataLoader

return DataLoader


DataLoader = get_data_loader_impl()


def get_batch_resolver(relationship_prop):
"""Get the resolve function for the given relationship."""

Expand Down
9 changes: 8 additions & 1 deletion graphene_sqlalchemy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,16 @@ def sort_argument_for_model(cls, has_default=True):
return Argument(List(enum), default_value=enum.default)


def is_sqlalchemy_version_less_than(version_string):
def is_sqlalchemy_version_less_than(version_string): # pragma: no cover
"""Check the installed SQLAlchemy version"""
return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string)


def is_graphene_version_less_than(version_string): # pragma: no cover
"""Check the installed graphene version"""
return pkg_resources.get_distribution('graphene').parsed_version < pkg_resources.parse_version(version_string)


class singledispatchbymatchfunction:
"""
Inspired by @singledispatch, this is a variant that works using a matcher function
Expand Down Expand Up @@ -197,6 +202,7 @@ def safe_isinstance_checker(arg):
return isinstance(arg, cls)
except TypeError:
pass

return safe_isinstance_checker


Expand All @@ -210,5 +216,6 @@ def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]:

class DummyImport:
"""The dummy module returns 'object' for a query for any member"""

def __getattr__(self, name):
return object

0 comments on commit 027ea0e

Please sign in to comment.