Skip to content

Commit

Permalink
Refactored once again
Browse files Browse the repository at this point in the history
It's not perfect just yet. There remains a lot of code in the concrete implementations.
  • Loading branch information
MrMatAP committed Jan 14, 2024
1 parent 2bcd1c6 commit abea460
Show file tree
Hide file tree
Showing 5 changed files with 357 additions and 175 deletions.
15 changes: 15 additions & 0 deletions .idea/dataSources.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

134 changes: 76 additions & 58 deletions src/kaso_mashin/common/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ class EntityInvariantException(KasoMashinException):
pass


class EntityMaterialisationException(KasoMashinException):
pass


class ORMBase(DeclarativeBase):
"""
ORM base class for persisted entities
"""
id: Mapped[str] = mapped_column(UUID(as_uuid=True).with_variant(String(32), 'sqlite'), primary_key=True)
uid: Mapped[str] = mapped_column(UUID(as_uuid=True).with_variant(String(32), 'sqlite'), primary_key=True)

@abc.abstractmethod
def merge(self, other: typing.Self):
Expand All @@ -49,93 +53,110 @@ def merge(self, other: typing.Self):
T_Model = typing.TypeVar('T_Model', bound=ORMBase)


@dataclasses.dataclass
class Entity(typing.Generic[T_Model]):
class ValueObject(abc.ABC):
"""
A domain entity
A domain value object
"""
# TODO: Owner should be either T_AggregateRoot or T_AsyncAggregateRoot
id: UniqueIdentifier = dataclasses.field(default_factory=lambda: uuid.uuid4())
owner: typing.Optional[typing.Any] = dataclasses.field(default=None)
pass


T_Entity = typing.TypeVar("T_Entity", bound=Entity)
T_ValueObject = typing.TypeVar('T_ValueObject', bound=ValueObject)


class ValueObject(abc.ABC):
class Entity(object):
"""
A domain value object
A domain entity
"""
pass

def __init__(self, owner: 'AggregateRoot', uid: UniqueIdentifier = uuid.uuid4()) -> None:
self._uid = uid
self._owner = owner

T_ValueObject = typing.TypeVar('T_ValueObject', bound=ValueObject)
@property
def uid(self) -> UniqueIdentifier:
return self._uid

def __eq__(self, other: object) -> bool:
return all([
isinstance(other, self.__class__),
self._uid == other._uid, # type: ignore[attr-defined]
self._owner == other._owner # type: ignore[attr-defined]
])


class AsyncAggregateRoot(typing.Generic[T_Entity, T_Model]):
T_Entity = typing.TypeVar("T_Entity", bound=Entity)


class AggregateRoot(typing.Generic[T_Entity, T_Model]):

def __init__(self, model: typing.Type[T_Model], session_maker: async_sessionmaker[AsyncSession]) -> None:
self._repository = AsyncRepository[T_Model](model=model, session_maker=session_maker)
self._identity_map: typing.Dict[UniqueIdentifier, T_Entity] = {}

async def get(self, uid: UniqueIdentifier) -> T_Entity:
model = await self._repository.get_by_id(str(uid))
entity = await self.from_model(model)
if not await self.validate(entity):
async def get(self, uid: UniqueIdentifier, force_reload: bool = False) -> T_Entity:
if not force_reload and uid in self._identity_map:
return self._identity_map[uid]
model = await self._repository.get_by_uid(str(uid))
entity = await self._from_model(model)
if not await self._validate(entity):
raise EntityInvariantException(code=500, msg='Restored entity fails validation')
entity.owner = self
self._identity_map[entity.id] = entity
return self._identity_map[entity.id]
self._identity_map[entity.uid] = entity
return self._identity_map[entity.uid]

async def list(self) -> typing.List[T_Entity]:
async def list(self, force_reload: bool = False) -> typing.List[T_Entity]:
if not force_reload:
return list(self._identity_map.values())
models = await self._repository.list()
entities = [await self.from_model(model) for model in models]
entities = [await self._from_model(model) for model in models]
for entity in entities:
if not await self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
entity.owner = self
self._identity_map.update({e.id: e for e in entities})
if not await self._validate(entity):
raise EntityInvariantException(code=400, msg='Entity fails validation')
self._identity_map.update({e.uid: e for e in entities})
return list(self._identity_map.values())

async def create(self, entity: T_Entity) -> T_Entity:
if not self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
model = await self._repository.create(await self.to_model(entity))
self._identity_map[entity.id] = await self.from_model(model)
self._identity_map[entity.id].owner = self
return self._identity_map[entity.id]

# TODO: Modify may make sense to be moved into the entity
async def modify(self, entity: T_Entity) -> T_Entity:
if entity.id not in self._identity_map:
if entity.uid in self._identity_map:
raise EntityInvariantException(code=400, msg='Entity already exists')
if not self._validate(entity):
raise EntityInvariantException(code=400, msg='Entity fails validation')
model = await self._repository.create(await self._to_model(entity))
self._identity_map[entity.uid] = await self._from_model(model)
return self._identity_map[entity.uid]

# Only methods in the entity should call this
async def modify(self, entity: T_Entity):
if entity.uid not in self._identity_map:
raise EntityInvariantException(code=400, msg='Entity was not created by its aggregate root')
if not self.validate(entity):
if not self._validate(entity):
raise EntityInvariantException(code=400, msg='Entity fails validation')
model = await self._repository.modify(await self.to_model(entity))
self._identity_map[entity.id] = await self.from_model(model)
self._identity_map[entity.id].owner = self
return self._identity_map[entity.id]
await self._repository.modify(await self._to_model(entity))

# TODO: Modify may make sense to be moved into the entity
# An entity should only be removed using this method
async def remove(self, uid: UniqueIdentifier):
if uid not in self._identity_map:
raise EntityInvariantException(code=400, msg='Entity was not created by its aggregate root')
await self._repository.remove(str(uid))
del self._identity_map[uid]

async def validate(self, entity: T_Entity) -> bool:
async def _validate(self, entity: T_Entity) -> bool:
return True

@abc.abstractmethod
async def to_model(self, entity: T_Entity) -> T_Model:
async def _to_model(self, entity: T_Entity) -> T_Model:
pass

@abc.abstractmethod
async def from_model(self, model: T_Model) -> T_Entity:
async def _from_model(self, model: T_Model) -> T_Entity:
pass


T_AsyncAggregateRoot = typing.TypeVar('T_AsyncAggregateRoot', bound=AsyncAggregateRoot)
T_AggregateRoot = typing.TypeVar('T_AggregateRoot', bound=AggregateRoot)


class DiskFormat(enum.StrEnum):
Raw = 'raw'
QCoW2 = 'qcow2'
VDI = 'vdi'


class BinaryScale(enum.StrEnum):
Expand Down Expand Up @@ -168,7 +189,7 @@ def __init__(self,
self._session_maker = session_maker
self._identity_map: typing.Dict[str, T_Model] = {}

async def get_by_id(self, uid: str) -> T_Model:
async def get_by_uid(self, uid: str) -> T_Model:
if uid in self._identity_map:
return self._identity_map[uid]
async with self._session_maker() as session:
Expand All @@ -180,33 +201,28 @@ async def get_by_id(self, uid: str) -> T_Model:

async def list(self) -> typing.List[T_Model]:
async with self._session_maker() as session:
# Alternative implementation
# with await session.stream_scalars(select(self._model_clazz)) as result:
# models = await result.all()
# for model in models:
# self._identity_map[UniqueIdentifier(model.id)] = self._aggregate_root_clazz.deserialise(model)
models = await session.scalars(select(self._model_clazz))
for model in models:
self._identity_map[model.id] = model
self._identity_map[model.uid] = model
# Note: list() is semantically better but mypy complains about an incompatible arg
return [i for i in self._identity_map.values()]

async def create(self, model: T_Model) -> T_Model:
async with self._session_maker() as session:
async with session.begin():
session.add(model)
self._identity_map[model.id] = model
session.add(model)
await session.commit()
self._identity_map[model.uid] = model
return model

async def modify(self, update: T_Model) -> T_Model:
async with self._session_maker() as session:
current = await session.get(self._model_clazz, str(update.id))
current = await session.get(self._model_clazz, str(update.uid))
if current is None:
raise EntityNotFoundException(code=400, msg='No such entity')
current.merge(update)
session.add(current)
await session.commit()
self._identity_map[update.id] = update
self._identity_map[update.uid] = update
return update

async def remove(self, uid: str) -> None:
Expand All @@ -216,3 +232,5 @@ async def remove(self, uid: str) -> None:
await session.delete(model)
await session.commit()
del self._identity_map[uid]


Loading

0 comments on commit abea460

Please sign in to comment.