Skip to content

Commit

Permalink
Refactored disks for DDD
Browse files Browse the repository at this point in the history
Remains unoptimised though
  • Loading branch information
MrMatAP committed Feb 25, 2024
1 parent d14beeb commit d58458f
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 96 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mypy==1.8.0 # MIT
pytest==7.4.3 # GPL-2.0-or-later
pytest-cov==4.1.0 # MIT
types-PyYAML==6.0.12.12 # Apache 2.0
types-aiofiles==23.2.0.20240106 # Apache 2.0

# Runtime requirements

Expand Down
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,19 @@ kaso-server = "kaso_mashin.server.run:main"
# in addopts
[tool.pytest.ini_options]
minversion = "6.0"
#addopts = "--cov=kaso_mashin --cov-report=term --cov-report=xml:build/coverage.xml --junit-xml=build/junit.xml"
addopts = "--cov=kaso_mashin --cov-report=term --cov-report=xml:build/coverage.xml --junit-xml=build/junit.xml"
testpaths = ["tests"]
junit_family = "xunit2"
log_cli = 1
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_cli_date_format="%Y-%m-%d %H:%M:%S"
asyncio_mode = "auto"

[tool.mypy]
plugins = [ 'pydantic.mypy' ]

[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
90 changes: 70 additions & 20 deletions src/kaso_mashin/common/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession

from pydantic import BaseModel, ConfigDict, Field


class KasoMashinException(Exception):

Expand Down Expand Up @@ -53,6 +55,17 @@ def merge(self, other: typing.Self):
T_Model = typing.TypeVar('T_Model', bound=ORMBase)


class SchemaBase(BaseModel):
    """
    Schema base class for serialised entities.

    Every serialised entity carries the unique identifier of the domain
    entity it represents; concrete schemata add their own fields.
    """
    # from_attributes allows pydantic to build the schema straight from an
    # entity/ORM object (e.g. SchemaBase.model_validate(entity)) instead of a dict
    model_config = ConfigDict(from_attributes=True)
    # uid mirrors Entity._uid; UniqueIdentifier is a project-declared alias (presumably uuid.UUID — confirm)
    uid: UniqueIdentifier = Field(description='The unique identifier', examples=['b430727e-2491-4184-bb4f-c7d6d213e093'])


T_Schema = typing.TypeVar('T_Schema', bound=SchemaBase)


class ValueObject(abc.ABC):
"""
A domain value object
Expand All @@ -63,31 +76,53 @@ class ValueObject(abc.ABC):
T_ValueObject = typing.TypeVar('T_ValueObject', bound=ValueObject)


class Entity(object):
class Entity:
"""
A domain entity
A domain base entity
All domain entities have a unique identity and a corresponding aggregate root as their owner
"""

def __init__(self, owner: 'AggregateRoot', uid: UniqueIdentifier = uuid.uuid4()) -> None:
self._uid = uid
def __init__(self, owner: 'AggregateRoot') -> None:
self._uid = uuid.uuid4()
self._owner = owner

@property
def uid(self) -> UniqueIdentifier:
return self._uid

@property
def owner(self) -> 'AggregateRoot':
return self._owner

@abc.abstractmethod
def schema_get(self):
pass

@staticmethod
@abc.abstractmethod
async def schema_create(owner, schema):
pass

@abc.abstractmethod
async def schema_modify(self, schema) -> 'Entity':
pass

def __eq__(self, other: object) -> bool:
return all([
isinstance(other, self.__class__),
self._uid == other._uid, # type: ignore[attr-defined]
self._owner == other._owner # type: ignore[attr-defined]
self._uid == other.uid, # type: ignore[attr-defined]
self._owner == other.owner # type: ignore[attr-defined]
])

def __repr__(self) -> str:
return f'<Entity(uid={self._uid})'


T_Entity = typing.TypeVar("T_Entity", bound=Entity)


class AggregateRoot(typing.Generic[T_Entity, T_Model]):
class AggregateRoot(typing.Generic[T_Entity, T_Model, T_Schema]):

def __init__(self, model: typing.Type[T_Model], session_maker: async_sessionmaker[AsyncSession]) -> None:
self._repository = AsyncRepository[T_Model](model=model, session_maker=session_maker)
Expand All @@ -97,8 +132,8 @@ async def get(self, uid: UniqueIdentifier, force_reload: bool = False) -> T_Enti
if not force_reload and uid in self._identity_map:
return self._identity_map[uid]
model = await self._repository.get_by_uid(str(uid))
entity = await self._from_model(model)
if not await self._validate(entity):
entity = self._from_model(model)
if not self.validate(entity):
raise EntityInvariantException(code=500, msg='Restored entity fails validation')
self._identity_map[entity.uid] = entity
return self._identity_map[entity.uid]
Expand All @@ -107,29 +142,29 @@ async def list(self, force_reload: bool = False) -> typing.List[T_Entity]:
if not force_reload:
return list(self._identity_map.values())
models = await self._repository.list()
entities = [await self._from_model(model) for model in models]
entities = [self._from_model(model) for model in models]
for entity in entities:
if not await self._validate(entity):
if not self.validate(entity):
raise EntityInvariantException(code=400, msg='Entity fails validation')
self._identity_map.update({e.uid: e for e in entities})
return list(self._identity_map.values())

async def create(self, entity: T_Entity) -> T_Entity:
if entity.uid in self._identity_map:
raise EntityInvariantException(code=400, msg='Entity already exists')
if not self._validate(entity):
if not self.validate(entity):
raise EntityInvariantException(code=400, msg='Entity fails validation')
model = await self._repository.create(await self._to_model(entity))
self._identity_map[entity.uid] = await self._from_model(model)
model = await self._repository.create(self._to_model(entity))
self._identity_map[entity.uid] = self._from_model(model)
return self._identity_map[entity.uid]

# Only methods in the entity should call this
async def modify(self, entity: T_Entity):
if entity.uid not in self._identity_map:
raise EntityInvariantException(code=400, msg='Entity was not created by its aggregate root')
if not self._validate(entity):
if not self.validate(entity):
raise EntityInvariantException(code=400, msg='Entity fails validation')
await self._repository.modify(await self._to_model(entity))
await self._repository.modify(self._to_model(entity))

# An entity should only be removed using this method
async def remove(self, uid: UniqueIdentifier):
Expand All @@ -138,15 +173,22 @@ async def remove(self, uid: UniqueIdentifier):
await self._repository.remove(str(uid))
del self._identity_map[uid]

async def _validate(self, entity: T_Entity) -> bool:
return True
def validate(self, entity: T_Entity) -> bool:
return all([
entity is not None,
isinstance(entity.uid, UniqueIdentifier)
])

@abc.abstractmethod
def _to_model(self, entity: T_Entity) -> T_Model:
pass

@abc.abstractmethod
async def _to_model(self, entity: T_Entity) -> T_Model:
def _from_model(self, model: T_Model) -> T_Entity:
pass

@abc.abstractmethod
async def _from_model(self, model: T_Model) -> T_Entity:
async def list_schema(self) -> typing.List[T_Schema]:
pass


Expand Down Expand Up @@ -179,6 +221,14 @@ class BinarySizedValue(ValueObject):
def __str__(self):
return f'{self.value}{self.scale.name}'

def __repr__(self):
    # Unambiguous debug representation, e.g. <BinarySizedValue(value=2, scale=G)>
    return f'<BinarySizedValue(value={self.value}, scale={self.scale.name})>'


class BinarySizedValueSchema(SchemaBase):
    """
    Serialised representation of a BinarySizedValue (a value + binary scale pair).

    NOTE(review): inheriting SchemaBase makes the required 'uid' field part of
    this schema, but BinarySizedValue is a value object and has no uid —
    confirm whether this should extend BaseModel instead.
    """
    value: int = Field(description="The value", examples=[2, 4, 8])
    scale: BinaryScale = Field(description="The binary scale", examples=[BinaryScale.M, BinaryScale.G, BinaryScale.T])


class AsyncRepository(typing.Generic[T_Model]):

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import typing
import pathlib
import subprocess
import uuid

from pydantic import Field, BaseModel

from sqlalchemy import String, Integer, Enum
from sqlalchemy.orm import Mapped, mapped_column

from kaso_mashin.common.base_types import (
KasoMashinException,
ORMBase,
ORMBase, SchemaBase,
Entity, AggregateRoot,
T_Schema,
BinarySizedValue,
UniqueIdentifier, BinaryScale, DiskFormat)

Expand Down Expand Up @@ -36,19 +39,44 @@ def merge(self, other: 'DiskModel'):
self.format = other.format


class DiskListSchema(SchemaBase):
    """
    Compact serialised form of a disk for list endpoints: uid (from SchemaBase) plus name.
    """
    name: str = Field(description='The disk name', examples=['root', 'data-1', 'data-2'])


class DiskGetSchema(SchemaBase):
    """
    Full serialised form of a single disk, built from a DiskEntity via
    model_validate (SchemaBase enables from_attributes).
    """
    name: str = Field(description='Disk name',
                      examples=['root', 'data-1', 'data-2'])
    path: pathlib.Path = Field(description='Path of the disk image on the local filesystem',
                               examples=['/var/kaso/instances/root.qcow2'])
    size: BinarySizedValue = Field(description='Disk size')
    disk_format: DiskFormat = Field(description='Disk image file format')


class DiskCreateSchema(BaseModel):
    """
    Input schema for creating a disk.

    Deliberately extends BaseModel rather than SchemaBase: the uid is not
    supplied by the caller — the entity generates its own on construction.
    """
    name: str = Field(description='Disk name',
                      examples=['root', 'data-1', 'data-2'])
    path: pathlib.Path = Field(description='Path of the disk image on the local filesystem',
                               examples=['/var/kaso/instances/root.qcow2'])
    size: BinarySizedValue = Field(description='Disk size')
    disk_format: DiskFormat = Field(description='Disk image file format')


class DiskModifySchema(BaseModel):
    """
    Input schema for modifying an existing disk; only the size is mutable here
    (see DiskEntity.schema_modify, which resizes when the size differs).
    """
    size: BinarySizedValue = Field(description='Disk size')


class DiskEntity(Entity):
"""
Domain model entity for a disk
"""

def __init__(self,
owner: 'AggregateRoot',
owner: 'DiskAggregateRoot',
name: str,
path: pathlib.Path,
size: BinarySizedValue = BinarySizedValue(2, BinaryScale.G),
disk_format: DiskFormat = DiskFormat.Raw,
uid: UniqueIdentifier = uuid.uuid4()) -> None:
super().__init__(owner=owner, uid=uid)
disk_format: DiskFormat = DiskFormat.Raw) -> None:
super().__init__(owner=owner)
self._name = name
self._path = path
self._size = size
Expand All @@ -70,16 +98,23 @@ def size(self) -> BinarySizedValue:
def disk_format(self) -> DiskFormat:
return self._disk_format

def __eq__(self, other: 'DiskEntity') -> bool: # type: ignore[override]
def __eq__(self, other: 'DiskEntity') -> bool: # type: ignore[override]
return all([
super().__eq__(other),
self._name == other.name,
self._path == other.path,
self._size == other.size,
self._disk_format == other.disk_format])

def __repr__(self) -> str:
return (f'<DiskEntity(uid={self.uid}, '
f'name={self.name}, '
f'path={self.path}, '
f'size={self.size}, '
f'disk_format={self.disk_format})>')

@staticmethod
async def create(owner: 'AggregateRoot',
async def create(owner: 'DiskAggregateRoot',
name: str,
path: pathlib.Path,
size: BinarySizedValue = BinarySizedValue(2, BinaryScale.G),
Expand All @@ -104,6 +139,22 @@ async def create(owner: 'AggregateRoot',
path.unlink(missing_ok=True)
raise DiskException(code=500, msg=f'Failed to create disk: {e.output}') from e

@staticmethod
async def schema_create(owner: 'DiskAggregateRoot', schema: DiskCreateSchema) -> 'DiskEntity':
    """
    Create and persist a new DiskEntity from an incoming DiskCreateSchema.

    Thin adapter over DiskEntity.create: unpacks the schema fields and
    normalises path into a pathlib.Path.
    """
    return await DiskEntity.create(owner=owner,
                                   name=schema.name,
                                   path=pathlib.Path(schema.path),
                                   size=schema.size,
                                   disk_format=schema.disk_format)

def schema_get(self) -> DiskGetSchema:
    """Serialise this entity into its GET schema (from_attributes mapping)."""
    return DiskGetSchema.model_validate(self)

async def schema_modify(self, schema: DiskModifySchema) -> 'DiskEntity':
    """
    Apply a modification schema to this disk.

    Only resizing is supported; the resize is skipped when the requested
    size already matches. Returns self for chaining.
    """
    if schema.size != self.size:
        await self.resize(schema.size)
    return self

async def resize(self, value: BinarySizedValue):
try:
subprocess.run(['/opt/homebrew/bin/qemu-img',
Expand All @@ -122,20 +173,31 @@ async def remove(self):
await self._owner.remove(self.uid)


class DiskAggregateRoot(AggregateRoot[DiskEntity, DiskModel]):
class DiskAggregateRoot(AggregateRoot[DiskEntity, DiskModel, DiskListSchema]):

def validate(self, entity: DiskEntity) -> bool:
    """
    Validate a disk entity: base invariants plus the disk image file
    actually existing on the local filesystem.

    Renamed from '_validate' to 'validate': the base AggregateRoot was
    refactored to call self.validate(entity) in get/list/create/modify,
    so an override named '_validate' was never invoked and the
    path-existence check silently did not run.
    """
    return all([
        super().validate(entity),
        entity.path.exists()
    ])

async def _to_model(self, entity: DiskEntity) -> DiskModel:
def _to_model(self, entity: DiskEntity) -> DiskModel:
    """
    Map a DiskEntity onto its ORM model for persistence.

    The BinarySizedValue is flattened into two columns (size, size_scale);
    uid, path and format are stored as strings.
    """
    return DiskModel(uid=str(entity.uid),
                     name=entity.name,
                     path=str(entity.path),
                     size=entity.size.value,
                     size_scale=entity.size.scale,
                     format=str(entity.disk_format))

async def _from_model(self, model: DiskModel) -> DiskEntity:
return DiskEntity(owner=self,
uid=UniqueIdentifier(model.uid),
name=model.name,
path=pathlib.Path(model.path),
size=BinarySizedValue(model.size, BinaryScale(model.size_scale)),
disk_format=DiskFormat(model.format))
def _from_model(self, model: DiskModel) -> 'DiskEntity':
    """
    Reconstitute a DiskEntity from its persisted ORM model.

    NOTE(review): the constructor generates a fresh uid, so the persisted
    identity is restored by writing the private '_uid' attribute directly —
    consider a dedicated restore constructor to avoid reaching into privates.
    """
    entity = DiskEntity(owner=self,
                        name=model.name,
                        path=pathlib.Path(model.path),
                        size=BinarySizedValue(model.size, BinaryScale(model.size_scale)),
                        disk_format=DiskFormat(model.format))
    entity._uid = UniqueIdentifier(model.uid)
    return entity

async def list_schema(self) -> typing.List[DiskListSchema]:
    """
    Return the list-schema projection of all disks.

    Forces a repository reload so the listing reflects current persistence
    state rather than the identity map cache.
    """
    disks = await self.list(force_reload=True)
    return [DiskListSchema.model_validate(disk) for disk in disks]
Loading

0 comments on commit d58458f

Please sign in to comment.