Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better abstract "group" types for Instrument and Union #103

Merged
merged 3 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 114 additions & 116 deletions src/ome_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,99 +91,61 @@ def __post_init__(self) -> None:
"BinData/Length": Override(type_="int"),
# FIXME: hard-coded subclass lists
"Instrument/LightSourceGroup": Override(
type_="List[LightSource]",
type_="List[LightSourceGroupType]",
default="Field(default_factory=list)",
imports="""
from typing import Dict, Union, Any
from pydantic import validator
from .light_source import LightSource
from .laser import Laser
from .arc import Arc
from .filament import Filament
from .light_emitting_diode import LightEmittingDiode
from .generic_excitation_source import GenericExcitationSource

_light_source_types: Dict[str, type] = {
"laser": Laser,
"arc": Arc,
"filament": Filament,
"light_emitting_diode": LightEmittingDiode,
"generic_excitation_source": GenericExcitationSource,
}
from pydantic import root_validator
from .light_source_group import LightSourceGroupType
""",
body="""
@validator("light_source_group", pre=True, each_item=True)
def validate_light_source_group(
cls, value: Union[LightSource, Dict[Any, Any]]
) -> LightSource:
if isinstance(value, LightSource):
return value
elif isinstance(value, dict):
try:
_type = value.pop("_type")
except KeyError:
raise ValueError(
"dict initialization requires _type"
) from None
try:
light_source_cls = _light_source_types[_type]
except KeyError:
raise ValueError(
f"unknown LightSource type '{_type}'"
) from None
return light_source_cls(**value)
else:
raise ValueError("invalid type for light_source_group values")
@root_validator(pre=True)
def _root(cls, value: Dict[str, Any]):
light_sources = {i.snake_name() for i in LightSourceGroupType.__args__} # type: ignore
lights = []
for key in list(value):
kind = {"kind": key}
if key in light_sources:
val = value.pop(key)
if isinstance(val, dict):
lights.append({**val, **kind})
elif isinstance(val, list):
lights.extend({**v, **kind} for v in val)
if lights:
value.setdefault("light_source_group", [])
value["light_source_group"].extend(lights)
return value
""",
),
"ROI/Union": Override(
type_="List[Shape]",
type_="List[ShapeGroupType]",
default="Field(default_factory=list)",
imports="""
from typing import Dict, Union, Any
from typing import Dict, Union, Any, Sequence, Iterator
from pydantic import validator
from .shape import Shape
from .point import Point
from .line import Line
from .rectangle import Rectangle
from .ellipse import Ellipse
from .polyline import Polyline
from .polygon import Polygon
from .mask import Mask
from .label import Label

_shape_types: Dict[str, type] = {
"point": Point,
"line": Line,
"rectangle": Rectangle,
"ellipse": Ellipse,
"polyline": Polyline,
"polygon": Polygon,
"mask": Mask,
"label": Label,
}

from .annotation_ref import AnnotationRef
from .shape_group import ShapeGroupType
from .simple_types import ROIID
""",
body="""
@validator("union", pre=True, each_item=True)
def validate_union(
cls, value: Union[Shape, Dict[Any, Any]]
) -> Shape:
if isinstance(value, Shape):
return value
elif isinstance(value, dict):
try:
_type = value.pop("_type")
except KeyError:
raise ValueError(
"dict initialization requires _type"
) from None
try:
shape_cls = _shape_types[_type]
except KeyError:
raise ValueError(f"unknown Shape type '{_type}'") from None
return shape_cls(**value)
else:
raise ValueError("invalid type for union values")
@validator("union", pre=True)
def _validate_union(cls, value: Any) -> Sequence[Dict[str, Any]]:
if isinstance(value, dict):
return list(cls._flatten_union_dict(value))
if not isinstance(value, Sequence):
raise TypeError("must be dict or sequence of dicts")
return value

@classmethod
def _flatten_union_dict(cls, nested: Dict[str, Any], keyname: str = "kind"
) -> Iterator[Dict[str, Any]]:
for key, value in nested.items():
keydict = {keyname: key} if keyname else {}
if isinstance(value, list):
yield from ({**x, **keydict} for x in value)
else:
yield {**value, **keydict}
""",
),
"OME/StructuredAnnotations": Override(
Expand Down Expand Up @@ -358,6 +320,42 @@ def dict(self, **k: Any) -> Dict[str, Any]:
"BinData": ClassOverride(base_type="object", fields="value: str"),
"Map": ClassOverride(fields_suppress={"K"}),
"M": ClassOverride(base_type="object", fields="value: str"),
"LightEmittingDiode": ClassOverride(
fields='kind: Literal["light_emitting_diode"] = "light_emitting_diode"',
imports="from typing_extensions import Literal",
),
"Laser": ClassOverride(
fields='kind: Literal["laser"] = "laser"',
imports="from typing_extensions import Literal",
),
"Arc": ClassOverride(
fields='kind: Literal["arc"] = "arc"',
imports="from typing_extensions import Literal",
),
"Filament": ClassOverride(
fields='kind: Literal["filament"] = "filament"',
imports="from typing_extensions import Literal",
),
"GenericExcitationSource": ClassOverride(
fields='kind: Literal["generic_excitation_source"] = "generic_excitation_source"',
imports="from typing_extensions import Literal",
),
"Label": ClassOverride(
fields='kind: Literal["label"] = "label"',
imports="from typing_extensions import Literal",
),
"Point": ClassOverride(
fields='kind: Literal["point"] = "point"',
imports="from typing_extensions import Literal",
),
"Mask": ClassOverride(
fields='kind: Literal["mask"] = "mask"',
imports="from typing_extensions import Literal",
),
"Rectangle": ClassOverride(
fields='kind: Literal["rectangle"] = "rectangle"',
imports="from typing_extensions import Literal",
),
}


Expand Down Expand Up @@ -491,6 +489,41 @@ def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]:
return lines


def make_abstract_class(component: XsdComponent) -> List[str]:
# FIXME: ? this might be a bit of an OME-schema-specific hack
# this seems to be how abstract is used in the OME schema
for e in component.iter_components():
if e != component:
raise NotImplementedError(
"Don't yet know how to handle abstract class with sub-components"
)

subs = [
el
for el in component.schema.elements.values()
if el.substitution_group == component.name
]

if not subs:
raise NotImplementedError(
"Don't know how to handle abstract class without substitutionGroups"
)

for el in subs:
if not el.type.is_extension() and el.type.base_type == component.type:
raise NotImplementedError(
"Expected all items in substitution group to extend "
f"the type {component.type} of Abstract element {component}"
)

sub_names = [el.local_name for el in subs]
lines = ["from typing import Union", *[local_import(n) for n in sub_names]]
lines += [local_import(component.type.local_name)]
lines += [f"{component.local_name} = {component.type.local_name}", ""]
lines += [f"{component.local_name}Type = Union[{', '.join(sub_names)}]"]
return lines


def make_enum(component: XsdComponent) -> List[str]:
name = component.local_name
_type = component.type if hasattr(component, "type") else component
Expand Down Expand Up @@ -952,46 +985,11 @@ def _simple_class(self) -> List[str]:
lines = ["import re", ""] + lines
return lines

def _abstract_class(self) -> List[str]:
# FIXME: ? this might be a bit of an OME-schema-specific hack
# this seems to be how abstract is used in the OME schema
for e in self.elem.iter_components():
if e != self.elem:
raise NotImplementedError(
"Don't yet know how to handle abstract class with sub-components"
)

subs = [
el
for el in self.elem.schema.elements.values()
if el.substitution_group == self.elem.name
]

if not subs:
raise NotImplementedError(
"Don't know how to handle abstract class without substitutionGroups"
)

for el in subs:
if not el.type.is_extension() and el.type.base_type == self.elem.type:
raise NotImplementedError(
"Expected all items in substitution group to extend "
f"the type {self.elem.type} of Abstract element {self.elem}"
)

sub_names = [el.local_name for el in subs]
lines = ["from typing import Union"]
lines.extend([local_import(n) for n in sub_names])
lines += [local_import(self.elem.type.local_name)]
lines += [f"{self.elem.local_name} = {self.elem.type.local_name}", ""]
lines += [f"{self.elem.local_name}Type = Union[{', '.join(sub_names)}]"]
return lines

def lines(self) -> str:
if not self.is_complex:
lines = self._simple_class()
elif self.elem.abstract:
lines = self._abstract_class()
lines = make_abstract_class(self.elem)
else:
lines = make_dataclass(self.elem)
return "\n".join(lines)
Expand Down
6 changes: 6 additions & 0 deletions src/ome_types/_base_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,9 @@ def __getstate__(self: Any) -> Dict[str, Any]:
state = super().__getstate__()
state["__private_attribute_values__"].pop("_ref", None)
return state

@classmethod
def snake_name(cls) -> str:
from .model import _camel_to_snake

return _camel_to_snake[cls.__name__]
38 changes: 0 additions & 38 deletions src/ome_types/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,44 +112,6 @@ def element_decode(self, data, xsd_element, xsd_type=None, level=0): # type: ig
elif xsd_element.local_name == "BinData":
if result["length"] == 0 and "value" not in result:
result["value"] = ""
elif xsd_element.local_name == "Instrument":
light_sources = []
for _type in (
"laser",
"arc",
"filament",
"light_emitting_diode",
"generic_excitation_source",
):
if _type in result:
values = result.pop(_type)
if isinstance(values, dict):
values = [values]
for v in values:
v["_type"] = _type
light_sources.extend(values)
if light_sources:
result["light_source_group"] = light_sources
elif xsd_element.local_name == "Union":
shapes = []
for _type in (
"point",
"line",
"rectangle",
"ellipse",
"polyline",
"polygon",
"mask",
"label",
):
if _type in result:
values = result.pop(_type)
if isinstance(values, dict):
values = [values]
for v in values:
v["_type"] = _type
shapes.extend(values)
result = shapes
elif xsd_element.local_name == "StructuredAnnotations":
annotations = []
for _type in (
Expand Down