Commit 2bcd1c6

Promoted the base types

MrMatAP committed Jan 13, 2024
1 parent a237c96 commit 2bcd1c6
Showing 5 changed files with 51 additions and 188 deletions.
@@ -8,7 +8,7 @@
import uuid

from sqlalchemy import UUID, String, select
from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session, Mapped, mapped_column
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession


@@ -76,104 +76,68 @@ class AsyncAggregateRoot(typing.Generic[T_Entity, T_Model]):

def __init__(self, model: typing.Type[T_Model], session_maker: async_sessionmaker[AsyncSession]) -> None:
self._repository = AsyncRepository[T_Model](model=model, session_maker=session_maker)
self._identity_map: typing.Dict[UniqueIdentifier, T_Entity] = {}

async def get(self, uid: UniqueIdentifier) -> T_Entity:
model = await self._repository.get_by_id(str(uid))
entity = self.deserialise(model)
if not self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
return entity
entity = await self.from_model(model)
if not await self.validate(entity):
raise EntityInvariantException(code=500, msg='Restored entity fails validation')
entity.owner = self
self._identity_map[entity.id] = entity
return self._identity_map[entity.id]

async def list(self) -> typing.List[T_Entity]:
models = await self._repository.list()
entities = [self.deserialise(model) for model in models]
bad_entities = [e for e in entities if not self.validate(e)]
if len(bad_entities) > 0:
raise EntityInvariantException(code=500, msg='Entity fails validation')
return entities
entities = [await self.from_model(model) for model in models]
for entity in entities:
if not await self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
entity.owner = self
self._identity_map.update({e.id: e for e in entities})
return list(self._identity_map.values())

async def create(self, entity: T_Entity) -> T_Entity:
if not self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
model = await self._repository.create(self.serialise(entity))
return self.deserialise(model)
model = await self._repository.create(await self.to_model(entity))
self._identity_map[entity.id] = await self.from_model(model)
self._identity_map[entity.id].owner = self
return self._identity_map[entity.id]

# TODO: Modify may make sense to be moved into the entity
async def modify(self, entity: T_Entity) -> T_Entity:
if entity.id not in self._identity_map:
raise EntityInvariantException(code=400, msg='Entity was not created by its aggregate root')
if not self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
model = await self._repository.modify(self.serialise(entity))
return self.deserialise(model)
raise EntityInvariantException(code=400, msg='Entity fails validation')
model = await self._repository.modify(await self.to_model(entity))
self._identity_map[entity.id] = await self.from_model(model)
self._identity_map[entity.id].owner = self
return self._identity_map[entity.id]

# TODO: Remove may make sense to be moved into the entity
async def remove(self, uid: UniqueIdentifier):
return await self._repository.remove(str(uid))
if uid not in self._identity_map:
raise EntityInvariantException(code=400, msg='Entity was not created by its aggregate root')
await self._repository.remove(str(uid))
del self._identity_map[uid]

@abc.abstractmethod
def validate(self, entity: T_Entity) -> bool:
pass
async def validate(self, entity: T_Entity) -> bool:
return True

@abc.abstractmethod
def serialise(self, entity: T_Entity) -> T_Model:
async def to_model(self, entity: T_Entity) -> T_Model:
pass

@abc.abstractmethod
def deserialise(self, model: T_Model) -> T_Entity:
async def from_model(self, model: T_Model) -> T_Entity:
pass


T_AsyncAggregateRoot = typing.TypeVar('T_AsyncAggregateRoot', bound=AsyncAggregateRoot)


class AggregateRoot(typing.Generic[T_Entity, T_Model]):

def __init__(self, model: typing.Type[T_Model], session_maker: sessionmaker[Session]) -> None:
self._repository = Repository[T_Model](model=model, session_maker=session_maker)

def get(self, uid: UniqueIdentifier) -> T_Entity:
model = self._repository.get_by_id(str(uid))
entity = self.deserialise(model)
if not self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
return entity

def list(self) -> typing.List[T_Entity]:
models = self._repository.list()
entities = [self.deserialise(model) for model in models]
bad_entities = [e for e in entities if not self.validate(e)]
if len(bad_entities) > 0:
raise EntityInvariantException(code=500, msg='Entity fails validation')
return entities

def create(self, entity: T_Entity) -> T_Entity:
if not self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
model = self._repository.create(self.serialise(entity))
return self.deserialise(model)

def modify(self, entity: T_Entity) -> T_Entity:
if not self.validate(entity):
raise EntityInvariantException(code=500, msg='Entity fails validation')
model = self._repository.modify(self.serialise(entity))
return self.deserialise(model)

def remove(self, uid: UniqueIdentifier):
return self._repository.remove(str(uid))

@abc.abstractmethod
def validate(self, entity: T_Entity) -> bool:
pass

@abc.abstractmethod
def serialise(self, entity: T_Entity) -> T_Model:
pass

@abc.abstractmethod
def deserialise(self, model: T_Model) -> T_Entity:
pass


T_AggregateRoot = typing.TypeVar('T_AggregateRoot', bound=AggregateRoot)


class BinaryScale(enum.StrEnum):
k = 'Kilobytes'
M = 'Megabytes'
@@ -252,56 +216,3 @@ async def remove(self, uid: str) -> None:
await session.delete(model)
await session.commit()
del self._identity_map[uid]


class Repository(typing.Generic[T_Model]):

def __init__(self,
model: typing.Type[T_Model],
session_maker: sessionmaker):
self._model_clazz = model
self._session_maker = session_maker
self._identity_map: typing.Dict[str, T_Model] = {}

def get_by_id(self, uid: str) -> T_Model:
if uid in self._identity_map:
return self._identity_map[uid]
with self._session_maker() as session:
model = session.get(self._model_clazz, str(uid))
if model is None:
raise EntityNotFoundException(code=400, msg='No such entity')
self._identity_map[uid] = model
return self._identity_map[uid]

def list(self) -> typing.List[T_Model]:
with self._session_maker() as session:
for model in session.scalars(select(self._model_clazz)).all():
self._identity_map[model.id] = model
# Note: list() is semantically better but mypy complains about an incompatible arg
return [i for i in self._identity_map.values()]

def create(self, model: T_Model) -> T_Model:
with self._session_maker() as session:
with session.begin():
session.add(model)
self._identity_map[model.id] = model
return model

def modify(self, update: T_Model) -> T_Model:
with self._session_maker() as session:
current = session.get(self._model_clazz, str(update.id))
if current is None:
raise EntityNotFoundException(code=400, msg='No such entity')
current.merge(update)
session.add(current)
session.commit()
self._identity_map[update.id] = update
return update

def remove(self, uid: str) -> None:
with self._session_maker() as session:
model = session.get(self._model_clazz, str(uid))
if model is not None:
session.delete(model)
session.commit()
del self._identity_map[uid]
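
With the synchronous Repository and AggregateRoot removed, the promoted base types are async-only. The following is a minimal wiring sketch (not part of the diff) of how the new API might be used, assuming an in-memory SQLite database reachable through the aiosqlite driver and the AsyncDiskAggregateRoot defined in disks.py below; names such as engine, session_maker and the '/tmp/scratch.qcow2' path are illustrative only.

import asyncio
import pathlib

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from kaso_mashin.common.base_types import ORMBase, BinaryScale, BinarySizedValue
from kaso_mashin.common.generics.disks import DiskEntity, DiskModel, AsyncDiskAggregateRoot


async def main() -> None:
    # Assumption: an in-memory SQLite database via aiosqlite; the real app wires its own engine
    engine = create_async_engine('sqlite+aiosqlite:///:memory:')
    async with engine.begin() as conn:
        await conn.run_sync(ORMBase.metadata.create_all)
    session_maker = async_sessionmaker(engine, expire_on_commit=False)

    disks = AsyncDiskAggregateRoot(model=DiskModel, session_maker=session_maker)
    disk = await disks.create(DiskEntity(name='scratch',
                                         path=pathlib.Path('/tmp/scratch.qcow2'),
                                         size=BinarySizedValue(1, BinaryScale.G)))
    # get() returns the identity-mapped entity restored from the persisted model
    assert await disks.get(disk.id) == disk
    await disks.remove(disk.id)


asyncio.run(main())
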
31 changes: 4 additions & 27 deletions src/kaso_mashin/common/generics/disks.py
@@ -6,8 +6,8 @@
from sqlalchemy import String, Integer, Enum
from sqlalchemy.orm import Mapped, mapped_column

from kaso_mashin.common.generics.base_types import Entity, BinarySizedValue, BinaryScale, KasoMashinException, ORMBase, \
AggregateRoot, T_Entity, UniqueIdentifier, AsyncAggregateRoot
from kaso_mashin.common.base_types import Entity, BinarySizedValue, BinaryScale, KasoMashinException, ORMBase, \
T_Entity, UniqueIdentifier, AsyncAggregateRoot
from kaso_mashin.common.generics.images import ImageEntity


@@ -96,39 +96,16 @@ def remove(self):
os.unlink(self.path)


class DiskAggregateRoot(AggregateRoot[DiskEntity, DiskModel]):

def validate(self, entity: T_Entity) -> bool:
return True

def serialise(self, entity: DiskEntity) -> DiskModel:
return DiskModel(id=str(entity.id),
name=entity.name,
path=str(entity.path),
size=entity.size.value,
size_scale=entity.size.scale)

def deserialise(self, model: DiskModel) -> DiskEntity:
return DiskEntity(owner=self,
id=UniqueIdentifier(model.id),
name=model.name,
path=pathlib.Path(model.path),
size=BinarySizedValue(model.size, BinaryScale(model.size_scale)))


class AsyncDiskAggregateRoot(AsyncAggregateRoot[DiskEntity, DiskModel]):

def validate(self, entity: T_Entity) -> bool:
return True

def serialise(self, entity: DiskEntity) -> DiskModel:
async def to_model(self, entity: DiskEntity) -> DiskModel:
return DiskModel(id=str(entity.id),
name=entity.name,
path=str(entity.path),
size=entity.size.value,
size_scale=entity.size.scale)

def deserialise(self, model: DiskModel) -> DiskEntity:
async def from_model(self, model: DiskModel) -> DiskEntity:
return DiskEntity(owner=self,
id=UniqueIdentifier(model.id),
name=model.name,
13 changes: 5 additions & 8 deletions src/kaso_mashin/common/generics/images.py
@@ -4,8 +4,8 @@
from sqlalchemy import String, Integer, Enum
from sqlalchemy.orm import Mapped, mapped_column

from kaso_mashin.common.generics.base_types import Entity, BinarySizedValue, BinaryScale, KasoMashinException, ORMBase, \
AggregateRoot, T_Entity, UniqueIdentifier
from kaso_mashin.common.base_types import Entity, BinarySizedValue, BinaryScale, KasoMashinException, ORMBase, \
AsyncAggregateRoot, UniqueIdentifier


class ImageException(KasoMashinException):
@@ -47,12 +47,9 @@ class ImageEntity(Entity[ImageModel]):
min_disk: BinarySizedValue = dataclasses.field(default_factory=lambda: BinarySizedValue(0, BinaryScale.G))


class ImageAggregateRoot(AggregateRoot[ImageEntity, ImageModel]):
class AsyncImageAggregateRoot(AsyncAggregateRoot[ImageEntity, ImageModel]):

def validate(self, entity: T_Entity) -> bool:
return True

def serialise(self, entity: ImageEntity) -> ImageModel:
async def to_model(self, entity: ImageEntity) -> ImageModel:
return ImageModel(id=str(entity.id),
name=entity.name,
path=str(entity.path),
@@ -62,7 +59,7 @@ def serialise(self, entity: ImageEntity) -> ImageModel:
min_disk=entity.min_disk.value,
min_disk_scale=entity.min_disk.scale)

def deserialise(self, model: ImageModel) -> ImageEntity:
async def from_model(self, model: ImageModel) -> ImageEntity:
return ImageEntity(owner=self,
id=UniqueIdentifier(model.id),
name=model.name,
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -18,7 +18,7 @@
from kaso_mashin.common.model import (
IdentityKind, IdentityModel)

from kaso_mashin.common.generics.base_types import ORMBase
from kaso_mashin.common.base_types import ORMBase

KasoTestContext = collections.namedtuple('KasoTestContext', 'runtime client')
KasoIdentity = collections.namedtuple('KasoIdentity',
32 changes: 5 additions & 27 deletions tests/test_generics.py
@@ -1,39 +1,17 @@
import pytest
import pathlib

from kaso_mashin.common.generics.base_types import BinaryScale, BinarySizedValue
from kaso_mashin.common.generics.disks import DiskEntity, DiskModel, DiskAggregateRoot, AsyncDiskAggregateRoot


def test_disks(generics_session_maker):
aggregate_root = DiskAggregateRoot(model=DiskModel, session_maker=generics_session_maker)

assert aggregate_root.list() == []
try:
disk = aggregate_root.create(DiskEntity(name='Test Disk',
path=pathlib.Path(__file__).parent / 'build' / 'test.qcow2',
size=BinarySizedValue(1, BinaryScale.G)))
loaded = aggregate_root.get(disk.id)
assert disk == loaded
disk.size = BinarySizedValue(2, scale=BinaryScale.G)
updated = aggregate_root.modify(disk)
assert disk == updated
listed = aggregate_root.list()
assert len(listed) == 1
assert disk == listed[0]
finally:
aggregate_root.remove(disk.id)
assert len(aggregate_root.list()) == 0
assert not disk.path.exists()
from kaso_mashin.common.base_types import BinaryScale, BinarySizedValue
from kaso_mashin.common.generics.disks import DiskEntity, DiskModel, AsyncDiskAggregateRoot


@pytest.mark.asyncio(scope='module')
async def test_async_disks(generics_async_session_maker):
aggregate_root = AsyncDiskAggregateRoot(model=DiskModel, session_maker=generics_async_session_maker)
disk = await aggregate_root.create(DiskEntity(name='Test Disk',
path=pathlib.Path(__file__).parent / 'build' / 'test.qcow2',
size=BinarySizedValue(1, BinaryScale.G)))
try:
disk = await aggregate_root.create(DiskEntity(name='Test Disk',
path=pathlib.Path(__file__).parent / 'build' / 'test.qcow2',
size=BinarySizedValue(1, BinaryScale.G)))
loaded = await aggregate_root.get(disk.id)
assert disk == loaded
disk.size = BinarySizedValue(2, scale=BinaryScale.G)
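
The async test relies on a generics_async_session_maker fixture from tests/conftest.py that is not part of this diff. A hypothetical sketch of what such a fixture could look like, assuming pytest-asyncio and an in-memory SQLite database via aiosqlite; the real fixture may differ.

import pytest_asyncio
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from kaso_mashin.common.base_types import ORMBase


@pytest_asyncio.fixture(scope='module')
async def generics_async_session_maker():
    # Assumption: an in-memory database keeps the tests self-contained
    engine = create_async_engine('sqlite+aiosqlite:///:memory:')
    async with engine.begin() as conn:
        await conn.run_sync(ORMBase.metadata.create_all)
    yield async_sessionmaker(engine, expire_on_commit=False)
    await engine.dispose()
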
