Skip to content

Commit

Permalink
sqlalchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
douwevandermeij committed Jul 5, 2021
1 parent e8a5a51 commit 5981db6
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fractal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Fractal is a scaffolding toolkit for building SOLID logic for your Python applications."""

__version__ = "0.1.9"
__version__ = "0.1.10"

from abc import ABC

Expand Down
Empty file.
175 changes: 175 additions & 0 deletions fractal/contrib/sqlalchemy/repositories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Dict, Generator, Optional, TypeVar, Generic

from fractal.core.specifications.id_specification import IdSpecification
from sqlalchemy import MetaData, Table, create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError, InterfaceError
from sqlalchemy.orm import Mapper, Session, sessionmaker

from fractal.contrib.sqlalchemy.specifications import SqlAlchemyOrmSpecificationBuilder
from fractal.core.repositories import Entity, Repository
from fractal.core.specifications.generic.specification import Specification

EntityDao = TypeVar("EntityDao")


class SqlAlchemyDao(ABC):
@staticmethod
@abstractmethod
def mapper(meta: MetaData, foreign_keys: Dict[str, Mapper]) -> Mapper:
raise NotImplementedError

@staticmethod
@abstractmethod
def table(meta: MetaData) -> Table:
raise NotImplementedError


class DaoMapper(ABC):
instance = None

def __new__(cls, *args, **kwargs):
if not isinstance(cls.instance, cls):
cls.instance = object.__new__(cls)
cls.instance.done = False
return cls.instance

def start_mappers(self, engine: Engine):
if not self.done:
meta = MetaData()
self.application_mappers(meta)
meta.create_all(engine)
self.done = True

@abstractmethod
def application_mappers(self, meta: MetaData):
raise NotImplementedError


class AbstractUnitOfWork(ABC):
def __enter__(self) -> AbstractUnitOfWork:
return self

def __exit__(self, *args):
self.rollback()

@abstractmethod
def commit(self):
raise NotImplementedError

@abstractmethod
def rollback(self):
raise NotImplementedError


class SqlAlchemyUnitOfWork(AbstractUnitOfWork):
def __init__(self):
self.session_factory = None

def __enter__(self) -> AbstractUnitOfWork:
self.session = self.session_factory() # type: Session
return super().__enter__()

def __exit__(self, *args):
super().__exit__(*args)
self.session.close()

def commit(self):
self.session.commit()

def rollback(self):
self.session.rollback()


class SqlAlchemyRepositoryMixin(Generic[Entity, EntityDao], Repository[Entity], SqlAlchemyUnitOfWork):
entity = Entity
entity_dao = EntityDao
application_mapper = DaoMapper

def __init__(self, connection_str: str):
super().__init__()

self.connection_str = connection_str
engine = create_engine(
self.connection_str,
)

self.application_mapper().start_mappers(engine)

self.session_factory = sessionmaker(
bind=engine,
expire_on_commit=False,
)

def add(self, entity: Entity) -> Entity:
entity_dao = self.entity_dao.from_domain(entity)
with self:
try:
self.session.add(entity_dao)
self.commit()
except IntegrityError:
raise
return entity

def update(self, entity: Entity, upsert=False) -> Entity:
self.remove_one(IdSpecification(entity.id))
return self.add(entity)

def remove_one(self, specification: Specification):
if entity := self._find_one_raw(specification):
self.session.delete(entity)
self.commit()

def find_one(self, specification: Specification) -> Optional[Entity]:
entity = self._find_one_raw(specification)
if entity:
return self.entity(**entity.__dict__)

def _find_one_raw(self, specification: Specification) -> Optional[Entity]:
_filter = SqlAlchemyOrmSpecificationBuilder.build(specification)
if isinstance(_filter, list):
entities = []
for f in _filter:
entities = self.session.query(self.entity_dao).filter_by(**dict(f))
try:
entities = list(self.session.query(self.entity_dao).filter_by(**dict(f)))
except InterfaceError:
pass
else:
break
else:
entities = self.session.query(self.entity_dao).filter_by(**dict(_filter))

for entity in filter(
lambda i: specification.is_satisfied_by(i), entities
):
return entity

def find(
self, specification: Optional[Specification] = None
) -> Generator[Entity, None, None]:
with self:
entities = self.session.query(self.entity_dao).all()

if specification:
entities = filter(
lambda i: specification.is_satisfied_by(i), entities.values()
)
for entity in entities:
d = entity.__dict__
if "_sa_instance_state" in d:
del d["_sa_instance_state"]
yield self.entity(**d)

def is_healthy(self) -> bool:
try:
with self:
self.session.execute("SELECT 1")
except Exception as e:
logging.exception(f"Database is unhealthy! {e}")
return False
return True
44 changes: 44 additions & 0 deletions fractal/contrib/sqlalchemy/specifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Collection, Optional

from fractal.core.exceptions import DomainException
from fractal.core.specifications.generic.collections import (
AndSpecification,
OrSpecification,
)
from fractal.core.specifications.generic.operators import (
EqualsSpecification,
)
from fractal.core.specifications.generic.specification import Specification


class SpecificationNotMappedToSqlAlchemyOrm(DomainException):
code = "SPECIFICATION_NOT_MAPPED_TO_SLQALCHEMY_ORM"
status_code = 500


class SqlAlchemyOrmSpecificationBuilder:
@staticmethod
def build(specification: Specification = None) -> Optional[Collection]:
if specification is None:
return None
elif isinstance(specification, OrSpecification):
return [
SqlAlchemyOrmSpecificationBuilder.build(spec)
for spec in specification.to_collection()
]
elif isinstance(specification, AndSpecification):
return {
k: v
for spec in specification.to_collection()
if (i := SqlAlchemyOrmSpecificationBuilder.build(spec))
for k, v in dict(i).items()
if isinstance(i, dict)
}
elif isinstance(specification, EqualsSpecification):
return {specification.field: specification.value}
elif isinstance(specification.to_collection(), dict):
for key, value in dict(specification.to_collection()).items():
return {key: value}
raise SpecificationNotMappedToSqlAlchemyOrm(
f"Specification '{specification}' not mapped to SqlAlchemy Orm query."
)
2 changes: 1 addition & 1 deletion fractal/core/utils/application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def load(self):

self.load_internal_services()
self.load_repositories()
self.load_ingress_services()
self.load_egress_services()
self.event_publisher = EventPublisher(self.load_event_projectors())
self.load_command_bus()
self.load_ingress_services()

for repository in self.repositories:
repository.is_healthy()
Expand Down

0 comments on commit 5981db6

Please sign in to comment.