Skip to content

Commit

Permalink
add repositories order by, offset & limit, clean-up code
Browse files Browse the repository at this point in the history
  • Loading branch information
douwevandermeij committed Mar 31, 2023
1 parent 2a52bb4 commit 2ef1adf
Show file tree
Hide file tree
Showing 17 changed files with 309 additions and 152 deletions.
2 changes: 1 addition & 1 deletion fractal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Fractal is a scaffolding toolkit for building SOLID logic for your Python applications."""

__version__ = "3.0.3"
__version__ = "3.1.0"

from abc import ABC

Expand Down
21 changes: 12 additions & 9 deletions fractal/contrib/django/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,20 @@ def find_one(self, specification: Specification) -> Optional[Entity]:
return self._obj_to_domain(self.__get_obj(specification).__dict__)

def find(
self, specification: Specification = None
self,
specification: Specification = None,
*,
offset: int = 0,
limit: int = 0,
order_by: str = "id",
) -> Generator[Entity, None, None]:
_filter = DjangoOrmSpecificationBuilder.build(specification)
if type(_filter) is list:
queryset = self.django_model.objects.filter(*_filter)
elif type(_filter) is dict:
queryset = self.django_model.objects.filter(**_filter)
elif type(_filter) is Q:
queryset = self.django_model.objects.filter(_filter)
if _filter := DjangoOrmSpecificationBuilder.build(specification):
queryset = self.django_model.objects.filter(_filter).order_by(order_by)
else:
queryset = self.django_model.objects.all()
queryset = self.django_model.objects.all().order_by(order_by)

if limit:
queryset.offset(offset).limit(limit)

for obj in queryset:
yield self._obj_to_domain(obj.__dict__)
Expand Down
55 changes: 32 additions & 23 deletions fractal/contrib/fastapi/routers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,44 +111,50 @@ def to_entity(self, **kwargs):


class BasicRestRouterService(DefaultRestRouterService):
def add_entity(
def find_entities(
self,
contract: Contract,
q: str = "",
*,
specification: Specification = None,
**kwargs,
):
try:
UUID(contract.id)
except (ValueError, TypeError):
contract.id = None
_entity = contract.to_entity(
user_id=kwargs.get("sub"), account_id=kwargs.get("account")
)
return self.ingress_service.add(
_entity, str(kwargs.get("sub")), specification=specification
return self.ingress_service.find(
account_id=str(kwargs.get("account")),
q=q,
specification=specification,
)

def find_entities(
def get_entity(
self,
q: str = "",
entity_id: UUID,
*,
specification: Specification = None,
**kwargs,
):
return self.ingress_service.find(
str(kwargs.get("account")), q, specification=specification
return self.ingress_service.get(
entity_id=str(entity_id),
acount_id=str(kwargs.get("account")),
specification=specification,
)

def get_entity(
def add_entity(
self,
entity_id: UUID,
contract: Contract,
*,
specification: Specification = None,
**kwargs,
):
return self.ingress_service.get(
str(entity_id), str(kwargs.get("account")), specification=specification
try:
UUID(contract.id)
except (ValueError, TypeError):
contract.id = None
_entity = contract.to_entity(
user_id=kwargs.get("sub"), account_id=kwargs.get("account")
)
return self.ingress_service.add(
entity=_entity,
user_id=str(kwargs.get("sub")),
specification=specification,
)

def update_entity(
Expand All @@ -170,7 +176,10 @@ def update_entity(
account_id=kwargs.get("account"),
)
return self.ingress_service.update(
str(entity_id), _entity, str(kwargs.get("sub")), specification=specification
entity_id=str(entity_id),
entity=_entity,
user_id=str(kwargs.get("sub")),
specification=specification,
)

def delete_entity(
Expand All @@ -181,9 +190,9 @@ def delete_entity(
**kwargs,
) -> Dict:
self.ingress_service.delete(
str(entity_id),
str(kwargs.get("sub")),
str(kwargs.get("account")),
entity_id=str(entity_id),
user_id=str(kwargs.get("sub")),
account_id=str(kwargs.get("account")),
specification=specification,
)
return {}
Expand Down
53 changes: 21 additions & 32 deletions fractal/contrib/gcp/firestore/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
FirestoreSpecificationBuilder,
)
from fractal_specifications.generic.specification import Specification
from google.cloud import firestore
from google.cloud.firestore_v1 import Client
from google.cloud.firestore_v1 import Client, Query

from fractal import Settings
from fractal.contrib.gcp import SettingsMixin
from fractal.core.exceptions import ObjectNotFoundException
from fractal.core.repositories import Entity, Repository
from fractal.core.repositories.sort_repository_mixin import SortRepositoryMixin


def get_firestore_client(settings: Settings):
Expand Down Expand Up @@ -93,15 +91,33 @@ def find_one(self, specification: Specification) -> Optional[Entity]:
raise self.object_not_found_exception
raise ObjectNotFoundException(f"{self.entity.__name__} not found!")

def find(self, specification: Specification = None) -> Iterator[Entity]:
def find(
self,
specification: Specification = None,
*,
offset: int = 0,
limit: int = 0,
order_by: str = "id",
) -> Iterator[Entity]:
_filter = FirestoreSpecificationBuilder.build(specification)
collection = self.collection
direction = Query.ASCENDING
if order_by.startswith("-"):
order_by = order_by[1:]
direction = Query.DESCENDING
collection = self.collection.order_by(order_by, direction=direction)
if _filter:
if isinstance(_filter, list):
for f in _filter:
collection = collection.where(*f)
else:
collection = collection.where(*_filter)
if limit:
if offset and (last := list(collection.limit(offset).stream())[-1]):
collection = collection.start_after(
{order_by: last.to_dict().get(order_by)}
).limit(limit)
else:
collection = collection.limit(limit)
for doc in collection.stream():
yield self.entity.from_dict(doc.to_dict())

Expand All @@ -123,30 +139,3 @@ def update(self, entity: Entity, *, upsert=False) -> Entity:
doc_ref.set(entity.asdict(skip_types=(date,)))
return entity
return self.add(entity)


class FirestoreSortRepositoryMixin(SortRepositoryMixin[Entity]):
def find_sort(
self, specification: Specification = None, *, order_by: str = "", limit: int = 0
) -> Iterator[Entity]:
_filter = FirestoreSpecificationBuilder.build(specification)
collection = self.collection
if _filter:
if isinstance(_filter, list):
for f in _filter:
collection = collection.where(*f)
else:
collection = collection.where(*_filter)
if order_by:
if reverse := order_by.startswith("-"):
order_by = order_by[1:]
collection = collection.order_by(
order_by,
direction=firestore.Query.DESCENDING
if reverse
else firestore.Query.ASCENDING,
)
if limit:
collection = collection.limit(limit)
for doc in collection.stream():
yield self.entity.from_dict(doc.to_dict())
26 changes: 23 additions & 3 deletions fractal/contrib/mongo/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,30 @@ def find_one(self, specification: Specification) -> Optional[Entity]:
return self._obj_to_domain(obj)

def find(
self, specification: Specification = None
self,
specification: Specification = None,
*,
offset: int = 0,
limit: int = 0,
order_by: str = "id",
) -> Generator[Entity, None, None]:
for obj in self.collection.find(MongoSpecificationBuilder.build(specification)):
yield self._obj_to_domain(obj)
direction = 1
if order_by.startswith("-"):
order_by = order_by[1:]
direction = -1
if limit:
for obj in (
self.collection.find(MongoSpecificationBuilder.build(specification))
.sort({order_by: direction})
.skip(offset)
.limit(limit)
):
yield self._obj_to_domain(obj)
else:
for obj in self.collection.find(
MongoSpecificationBuilder.build(specification)
).sort({order_by: direction}):
yield self._obj_to_domain(obj)

def is_healthy(self) -> bool:
ok = self.client.server_info().get("ok", False)
Expand Down
70 changes: 57 additions & 13 deletions fractal/contrib/sqlalchemy/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ArgumentError, IntegrityError
from sqlalchemy.orm import Mapper, Session, sessionmaker # NOQA
from sqlalchemy.sql.elements import BooleanClauseList

from fractal.core.exceptions import DomainException
from fractal.core.repositories import Entity, Repository
Expand Down Expand Up @@ -217,9 +218,19 @@ def find_one(self, specification: Specification) -> Optional[Entity]:
return self._dao_to_domain(entity)

def find(
self, specification: Optional[Specification] = None
self,
specification: Optional[Specification] = None,
*,
offset: int = 0,
limit: int = 0,
order_by: str = "id",
) -> Generator[Entity, None, None]:
entities = self._find_raw(specification)
entities = self._find_raw(
specification=specification,
offset=offset,
limit=limit,
order_by=order_by,
)

if specification:
entities = filter(lambda i: specification.is_satisfied_by(i), entities)
Expand Down Expand Up @@ -265,25 +276,58 @@ def _find_raw(
specification: Optional[Specification],
*,
entity_dao_class: Optional[SqlAlchemyDao] = None,
offset: int = 0,
limit: int = 0,
order_by: str = "id",
) -> List[Entity]:
_filter = {}
if specification:
_filter = SqlAlchemyOrmSpecificationBuilder.build(specification)
if isinstance(_filter, list):
entities = []
filters = {}
for f in _filter:
entities.extend(
list(
self.session.query(
entity_dao_class or self.entity_dao
).filter_by(**dict(f))
)
filters.update(f)
from sqlalchemy import or_

# TODO move to SqlAlchemyOrmSpecificationBuilder
filters = or_(
*[
getattr(entity_dao_class or self.entity_dao, k) == v
for k, v in filters.items()
]
)
else:
from sqlalchemy import and_

if len(_filter) > 1:
# TODO move to SqlAlchemyOrmSpecificationBuilder
filters = and_(
*[
getattr(entity_dao_class or self.entity_dao, k) == v
for k, v in _filter.items()
]
)
return entities
else:
filters = _filter

if order_by.startswith("-"):
_order_by = getattr(entity_dao_class or self.entity_dao, order_by[1:])
desc = True
else:
return self.session.query(entity_dao_class or self.entity_dao).filter_by(
**dict(_filter)
)
_order_by = getattr(entity_dao_class or self.entity_dao, order_by)
desc = False

ret = self.session.query(entity_dao_class or self.entity_dao)
if type(filters) == dict:
ret = ret.filter_by(**filters)
if type(filters) == BooleanClauseList:
ret = ret.where(filters)
ret = ret.order_by(_order_by.desc() if desc else _order_by)
if limit:
ret = ret.offset(offset)
ret = ret.limit(limit)
return ret
return ret

def is_healthy(self) -> bool:
try:
Expand Down
9 changes: 8 additions & 1 deletion fractal/core/repositories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ def find_one(self, specification: Specification) -> Optional[Entity]:
raise NotImplementedError

@abstractmethod
def find(self, specification: Specification = None) -> Iterator[Entity]:
def find(
self,
specification: Specification = None,
*,
offset: int = 0,
limit: int = 0,
order_by: str = "id",
) -> Iterator[Entity]:
raise NotImplementedError

@abstractmethod
Expand Down
Loading

0 comments on commit 2ef1adf

Please sign in to comment.