Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add transformations to from_xml #208

Merged
merged 5 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 51 additions & 28 deletions src/ome_types/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import lru_cache
from pathlib import Path
from struct import Struct
from typing import TYPE_CHECKING, cast, overload
from typing import TYPE_CHECKING, Callable, Iterable, cast, overload

from pydantic import BaseModel
from xsdata.formats.dataclass.parsers.config import ParserConfig
Expand All @@ -28,6 +28,7 @@

if TYPE_CHECKING:
from typing import Any, BinaryIO, ContextManager, TypedDict
from xml.etree import ElementTree

import xmlschema
from lxml.etree import _XSLTResultTree
Expand All @@ -38,9 +39,12 @@
from ome_types.model import OME
from xsdata_pydantic_basemodel.bindings import XmlContext

AnyElement = ET._Element | ElementTree.Element
AnyElementTree = ElementTree.ElementTree | ET._ElementTree
ElementOrTree = AnyElement | AnyElementTree
TransformationCallable = Callable[[AnyElementTree], AnyElementTree]
XMLSource = Path | str | bytes | BinaryIO
FileLike = str | io.BufferedIOBase
ElementOrTree = ET._Element | ET._ElementTree

class ParserKwargs(TypedDict, total=False):
config: ParserConfig
Expand All @@ -62,7 +66,8 @@ def from_xml(
validate: bool | None = None,
parser: Any = None,
parser_kwargs: ParserKwargs | None = None,
warn_on_transform: bool = False,
transformations: Iterable[TransformationCallable] = (),
warn_on_schema_update: bool = False,
) -> OME: # Not totally true, see note below
"""Generate an OME object from an XML document.

Expand All @@ -87,7 +92,12 @@ def from_xml(
parser_kwargs : ParserKwargs | None
Passed to the XmlParser constructor. If None, a default parser
will be used.
warn_on_transform : bool
transformations: Iterable[TransformationCallable]
A sequence of functions that take an ElementTree and return an ElementTree.
These will be applied sequentially to the XML document before parsing.
Can be used to apply custom transformations or fixes to the XML document
before parsing.
warn_on_schema_update : bool
Whether to warn if a transformation was applied to bring the document to
OME-2016-06.

Expand All @@ -106,12 +116,19 @@ def from_xml(
)

if validate:
xml_2016 = validate_xml(source, warn_on_transform=warn_on_transform)
xml_2016 = validate_xml(source, warn_on_schema_update=warn_on_schema_update)
else:
xml_2016 = ensure_2016(
source, warn_on_transform=warn_on_transform, as_tree=True
source, warn_on_schema_update=warn_on_schema_update, as_tree=True
)

for transform in transformations:
tree_out = transform(xml_2016)
if tree_out is not None:
xml_2016 = tree_out
else:
warnings.warn("Transformation returned None, skipping", stacklevel=2)

OME_type = _get_root_ome_type(xml_2016)
parser = XmlParser(**(parser_kwargs or {}))
return parser.parse(xml_2016, OME_type)
Expand Down Expand Up @@ -313,17 +330,20 @@ class ValidationError(ValueError):


def validate_xml(
xml: XMLSource, schema: Path | str | None = None, *, warn_on_transform: bool = True
) -> ET._ElementTree:
xml: XMLSource,
schema: Path | str | None = None,
*,
warn_on_schema_update: bool = True,
) -> AnyElementTree:
"""Validate XML against an XML Schema.

By default, will validate against the OME 2016-06 schema.
"""
with suppress(ImportError):
return validate_xml_with_lxml(xml, schema, warn_on_transform)
return validate_xml_with_lxml(xml, schema, warn_on_schema_update)

with suppress(ImportError): # pragma: no cover
return validate_xml_with_xmlschema(xml, schema, warn_on_transform)
return validate_xml_with_xmlschema(xml, schema, warn_on_schema_update)

raise ImportError( # pragma: no cover
"Validation requires either `lxml` or `xmlschema`. "
Expand All @@ -332,15 +352,15 @@ def validate_xml(


def validate_xml_with_lxml(
xml: XMLSource, schema: Path | str | None = None, warn_on_transform: bool = True
) -> ET._ElementTree:
xml: XMLSource, schema: Path | str | None = None, warn_on_schema_update: bool = True
) -> AnyElementTree:
"""Validate XML against an XML Schema using lxml."""
from lxml import etree

tree = ensure_2016(xml, warn_on_transform=warn_on_transform, as_tree=True)
tree = ensure_2016(xml, warn_on_schema_update=warn_on_schema_update, as_tree=True)
xmlschema = etree.XMLSchema(etree.parse(schema or OME_2016_06_XSD))

if not xmlschema.validate(tree):
if not xmlschema.validate(cast("ET._ElementTree", tree)):
msg = f"Validation of {str(xml)[:20]!r} failed:"
for error in xmlschema.error_log:
msg += f"\n - line {error.line}: {error.message}"
Expand All @@ -349,12 +369,12 @@ def validate_xml_with_lxml(


def validate_xml_with_xmlschema(
xml: XMLSource, schema: Path | str | None = None, warn_on_transform: bool = True
) -> ET._ElementTree:
xml: XMLSource, schema: Path | str | None = None, warn_on_schema_update: bool = True
) -> AnyElementTree:
"""Validate XML against an XML Schema using xmlschema."""
from xmlschema.exceptions import XMLSchemaException

tree = ensure_2016(xml, warn_on_transform=warn_on_transform, as_tree=True)
tree = ensure_2016(xml, warn_on_schema_update=warn_on_schema_update, as_tree=True)
xmlschema = _get_XMLSchema(schema or OME_2016_06_XSD)
try:
xmlschema.validate(tree) # type: ignore[arg-type]
Expand Down Expand Up @@ -386,21 +406,24 @@ def _get_XMLSchema(schema: Path | str) -> xmlschema.XMLSchema:

@overload
def ensure_2016(
source: XMLSource, *, warn_on_transform: bool = ..., as_tree: Literal[True]
) -> ET._ElementTree:
source: XMLSource, *, warn_on_schema_update: bool = ..., as_tree: Literal[True]
) -> AnyElementTree:
...


@overload
def ensure_2016(
source: XMLSource, *, warn_on_transform: bool = ..., as_tree: Literal[False] = ...
source: XMLSource,
*,
warn_on_schema_update: bool = ...,
as_tree: Literal[False] = ...,
) -> FileLike:
...


def ensure_2016(
source: XMLSource, *, warn_on_transform: bool = False, as_tree: bool = False
) -> FileLike | ET._ElementTree:
source: XMLSource, *, warn_on_schema_update: bool = False, as_tree: bool = False
) -> FileLike | AnyElementTree:
"""Ensure source is OME-2016-06 XML.

If the source is not OME-2016-06 XML, it will be transformed sequentially using
Expand All @@ -411,15 +434,15 @@ def ensure_2016(
----------
source : Path | str | bytes | io.BytesIO
Path to an XML file, string or bytes containing XML, or a file-like object.
warn_on_transform : bool
warn_on_schema_update : bool
Whether to warn if a transformation was applied to bring the document to
OME-2016-06.
as_tree : bool
Whether to return an ElementTree or a FileLike object.

Returns
-------
FileLike | ET._ElementTree
FileLike | AnyElementTree
If `as_tree` is `True`, an ElementTree, otherwise a FileLike object representing
transformed OME 2016 XML.

Expand Down Expand Up @@ -456,7 +479,7 @@ def ensure_2016(
while ns in TRANSFORMS:
tree = _apply_xslt(tree, TRANSFORMS[ns])
ns = _get_ns_elem(tree)
if warn_on_transform:
if warn_on_schema_update:
warnings.warn(
f"Transformed source from {ns_in!r} to {OME_2016_06_URI!r}",
stacklevel=2,
Expand Down Expand Up @@ -510,7 +533,7 @@ def _normalize(source: XMLSource) -> FileLike:
raise TypeError(f"Unsupported source type {type(source)!r}")


def _apply_xslt(root: ElementOrTree, xslt_path: str | Path) -> _XSLTResultTree:
def _apply_xslt(root: ET._ElementTree, xslt_path: str | Path) -> _XSLTResultTree:
"""Apply an XSLT transform to an element or element tree."""
try:
from lxml import etree
Expand All @@ -530,7 +553,7 @@ def _apply_xslt(root: ElementOrTree, xslt_path: str | Path) -> _XSLTResultTree:
# ------------------------


def _get_ns_elem(elem: ET._Element | ET._ElementTree) -> str:
def _get_ns_elem(elem: ET._Element | AnyElementTree) -> str:
"""Get namespace from an element or element tree."""
root = elem.getroot() if hasattr(elem, "getroot") else elem
# return root.nsmap[root.prefix] this only works for lxml
Expand All @@ -545,7 +568,7 @@ def _get_ns_file(source: FileLike) -> str:
return _get_ns_elem(root) # type: ignore[arg-type]


def _get_root_ome_type(xml: FileLike | ET._ElementTree) -> type[OMEType]:
def _get_root_ome_type(xml: FileLike | AnyElementTree) -> type[OMEType]:
"""Resolve a ome_types.model class for the root element of an OME XML document."""
from ome_types import model

Expand Down
45 changes: 45 additions & 0 deletions src/ome_types/etree_fixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from ome_types._conversion import OME_2016_06_URI

if TYPE_CHECKING:
from ome_types._conversion import AnyElementTree


NSMAP = {"": OME_2016_06_URI, "ome": OME_2016_06_URI}


# See note before using...
def fix_micro_manager_instrument(tree: AnyElementTree) -> AnyElementTree:
"""Fix MicroManager Instrument and Detector IDs and References.

Some versions of OME-XML produced by MicroManager have invalid IDs (and references)
for Instruments and Detectors. This function fixes those IDs and references.

NOTE: as of v0.4.0, bad IDs and references are caught during ID validation anyway,
so this is mostly an example of a fix function, and could be used to prevent
the warning from being raised.
"""
for i_idx, instrument in enumerate(tree.findall("Instrument", NSMAP)):
old_id = instrument.get("ID")
if old_id.startswith("Microscope"):
new_id = f"Instrument:{i_idx}"
instrument.set("ID", new_id)
for ref in tree.findall(f".//InstrumentRef[@ID='{old_id}']", NSMAP):
ref.set("ID", new_id)

for d_idx, detector in enumerate(instrument.findall(".//Detector", NSMAP)):
old_id = detector.get("ID")
if not old_id.startswith("Detector:"):
new_id = f"Detector:{old_id if old_id.isdigit() else d_idx}"
detector.set("ID", new_id)
for ref in tree.findall(f".//DetectorSettings[@ID='{old_id}']", NSMAP):
ref.set("ID", new_id)

return tree


ALL_FIXES = [fix_micro_manager_instrument]
__all__ = ["ALL_FIXES", "fix_micro_manager_instrument"]
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

DATA = Path(__file__).parent / "data"
ALL_XML = set(DATA.glob("*.ome.xml"))
INVALID = {DATA / "invalid_xml_annotation.ome.xml", DATA / "bad.ome.xml"}
INVALID = {
DATA / "invalid_xml_annotation.ome.xml",
DATA / "bad.ome.xml",
DATA / "MMStack.ome.xml",
}
OLD_SCHEMA = {DATA / "seq0000xy01c1.ome.xml", DATA / "2008_instrument.ome.xml"}
WITH_XML_ANNOTATIONS = {
DATA / "ome_ns.ome.xml",
Expand Down
Loading