Skip to content

Commit

Permalink
feat: add transformations to from_xml (#208)
Browse files Browse the repository at this point in the history
* feat: add transformations to from_xml

* docs

* update tests

* skip on 3.7
  • Loading branch information
tlambert03 authored Jul 16, 2023
1 parent 46bae0c commit 51b6f2e
Show file tree
Hide file tree
Showing 6 changed files with 529 additions and 30 deletions.
79 changes: 51 additions & 28 deletions src/ome_types/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import lru_cache
from pathlib import Path
from struct import Struct
from typing import TYPE_CHECKING, cast, overload
from typing import TYPE_CHECKING, Callable, Iterable, cast, overload

from pydantic import BaseModel
from xsdata.formats.dataclass.parsers.config import ParserConfig
Expand All @@ -28,6 +28,7 @@

if TYPE_CHECKING:
from typing import Any, BinaryIO, ContextManager, TypedDict
from xml.etree import ElementTree

import xmlschema
from lxml.etree import _XSLTResultTree
Expand All @@ -38,9 +39,12 @@
from ome_types.model import OME
from xsdata_pydantic_basemodel.bindings import XmlContext

AnyElement = ET._Element | ElementTree.Element
AnyElementTree = ElementTree.ElementTree | ET._ElementTree
ElementOrTree = AnyElement | AnyElementTree
TransformationCallable = Callable[[AnyElementTree], AnyElementTree]
XMLSource = Path | str | bytes | BinaryIO
FileLike = str | io.BufferedIOBase
ElementOrTree = ET._Element | ET._ElementTree

class ParserKwargs(TypedDict, total=False):
config: ParserConfig
Expand All @@ -62,7 +66,8 @@ def from_xml(
validate: bool | None = None,
parser: Any = None,
parser_kwargs: ParserKwargs | None = None,
warn_on_transform: bool = False,
transformations: Iterable[TransformationCallable] = (),
warn_on_schema_update: bool = False,
) -> OME: # Not totally true, see note below
"""Generate an OME object from an XML document.
Expand All @@ -87,7 +92,12 @@ def from_xml(
parser_kwargs : ParserKwargs | None
Passed to the XmlParser constructor. If None, a default parser
will be used.
warn_on_transform : bool
transformations : Iterable[TransformationCallable]
A sequence of functions that take an ElementTree and return an ElementTree.
These will be applied sequentially to the XML document before parsing.
Can be used to apply custom transformations or fixes to the XML document
before parsing.
warn_on_schema_update : bool
Whether to warn if a transformation was applied to bring the document to
OME-2016-06.
Expand All @@ -106,12 +116,19 @@ def from_xml(
)

if validate:
xml_2016 = validate_xml(source, warn_on_transform=warn_on_transform)
xml_2016 = validate_xml(source, warn_on_schema_update=warn_on_schema_update)
else:
xml_2016 = ensure_2016(
source, warn_on_transform=warn_on_transform, as_tree=True
source, warn_on_schema_update=warn_on_schema_update, as_tree=True
)

for transform in transformations:
tree_out = transform(xml_2016)
if tree_out is not None:
xml_2016 = tree_out
else:
warnings.warn("Transformation returned None, skipping", stacklevel=2)

OME_type = _get_root_ome_type(xml_2016)
parser = XmlParser(**(parser_kwargs or {}))
return parser.parse(xml_2016, OME_type)
Expand Down Expand Up @@ -313,17 +330,20 @@ class ValidationError(ValueError):


def validate_xml(
xml: XMLSource, schema: Path | str | None = None, *, warn_on_transform: bool = True
) -> ET._ElementTree:
xml: XMLSource,
schema: Path | str | None = None,
*,
warn_on_schema_update: bool = True,
) -> AnyElementTree:
"""Validate XML against an XML Schema.
By default, will validate against the OME 2016-06 schema.
"""
with suppress(ImportError):
return validate_xml_with_lxml(xml, schema, warn_on_transform)
return validate_xml_with_lxml(xml, schema, warn_on_schema_update)

with suppress(ImportError): # pragma: no cover
return validate_xml_with_xmlschema(xml, schema, warn_on_transform)
return validate_xml_with_xmlschema(xml, schema, warn_on_schema_update)

raise ImportError( # pragma: no cover
"Validation requires either `lxml` or `xmlschema`. "
Expand All @@ -332,15 +352,15 @@ def validate_xml(


def validate_xml_with_lxml(
xml: XMLSource, schema: Path | str | None = None, warn_on_transform: bool = True
) -> ET._ElementTree:
xml: XMLSource, schema: Path | str | None = None, warn_on_schema_update: bool = True
) -> AnyElementTree:
"""Validate XML against an XML Schema using lxml."""
from lxml import etree

tree = ensure_2016(xml, warn_on_transform=warn_on_transform, as_tree=True)
tree = ensure_2016(xml, warn_on_schema_update=warn_on_schema_update, as_tree=True)
xmlschema = etree.XMLSchema(etree.parse(schema or OME_2016_06_XSD))

if not xmlschema.validate(tree):
if not xmlschema.validate(cast("ET._ElementTree", tree)):
msg = f"Validation of {str(xml)[:20]!r} failed:"
for error in xmlschema.error_log:
msg += f"\n - line {error.line}: {error.message}"
Expand All @@ -349,12 +369,12 @@ def validate_xml_with_lxml(


def validate_xml_with_xmlschema(
xml: XMLSource, schema: Path | str | None = None, warn_on_transform: bool = True
) -> ET._ElementTree:
xml: XMLSource, schema: Path | str | None = None, warn_on_schema_update: bool = True
) -> AnyElementTree:
"""Validate XML against an XML Schema using xmlschema."""
from xmlschema.exceptions import XMLSchemaException

tree = ensure_2016(xml, warn_on_transform=warn_on_transform, as_tree=True)
tree = ensure_2016(xml, warn_on_schema_update=warn_on_schema_update, as_tree=True)
xmlschema = _get_XMLSchema(schema or OME_2016_06_XSD)
try:
xmlschema.validate(tree) # type: ignore[arg-type]
Expand Down Expand Up @@ -386,21 +406,24 @@ def _get_XMLSchema(schema: Path | str) -> xmlschema.XMLSchema:

@overload
def ensure_2016(
source: XMLSource, *, warn_on_transform: bool = ..., as_tree: Literal[True]
) -> ET._ElementTree:
source: XMLSource, *, warn_on_schema_update: bool = ..., as_tree: Literal[True]
) -> AnyElementTree:
...


@overload
def ensure_2016(
source: XMLSource, *, warn_on_transform: bool = ..., as_tree: Literal[False] = ...
source: XMLSource,
*,
warn_on_schema_update: bool = ...,
as_tree: Literal[False] = ...,
) -> FileLike:
...


def ensure_2016(
source: XMLSource, *, warn_on_transform: bool = False, as_tree: bool = False
) -> FileLike | ET._ElementTree:
source: XMLSource, *, warn_on_schema_update: bool = False, as_tree: bool = False
) -> FileLike | AnyElementTree:
"""Ensure source is OME-2016-06 XML.
If the source is not OME-2016-06 XML, it will be transformed sequentially using
Expand All @@ -411,15 +434,15 @@ def ensure_2016(
----------
source : Path | str | bytes | io.BytesIO
Path to an XML file, string or bytes containing XML, or a file-like object.
warn_on_transform : bool
warn_on_schema_update : bool
Whether to warn if a transformation was applied to bring the document to
OME-2016-06.
as_tree : bool
Whether to return an ElementTree or a FileLike object.
Returns
-------
FileLike | ET._ElementTree
FileLike | AnyElementTree
If `as_tree` is `True`, an ElementTree, otherwise a FileLike object representing
transformed OME 2016 XML.
Expand Down Expand Up @@ -456,7 +479,7 @@ def ensure_2016(
while ns in TRANSFORMS:
tree = _apply_xslt(tree, TRANSFORMS[ns])
ns = _get_ns_elem(tree)
if warn_on_transform:
if warn_on_schema_update:
warnings.warn(
f"Transformed source from {ns_in!r} to {OME_2016_06_URI!r}",
stacklevel=2,
Expand Down Expand Up @@ -510,7 +533,7 @@ def _normalize(source: XMLSource) -> FileLike:
raise TypeError(f"Unsupported source type {type(source)!r}")


def _apply_xslt(root: ElementOrTree, xslt_path: str | Path) -> _XSLTResultTree:
def _apply_xslt(root: ET._ElementTree, xslt_path: str | Path) -> _XSLTResultTree:
"""Apply an XSLT transform to an element or element tree."""
try:
from lxml import etree
Expand All @@ -530,7 +553,7 @@ def _apply_xslt(root: ElementOrTree, xslt_path: str | Path) -> _XSLTResultTree:
# ------------------------


def _get_ns_elem(elem: ET._Element | ET._ElementTree) -> str:
def _get_ns_elem(elem: ET._Element | AnyElementTree) -> str:
"""Get namespace from an element or element tree."""
root = elem.getroot() if hasattr(elem, "getroot") else elem
# return root.nsmap[root.prefix] this only works for lxml
Expand All @@ -545,7 +568,7 @@ def _get_ns_file(source: FileLike) -> str:
return _get_ns_elem(root) # type: ignore[arg-type]


def _get_root_ome_type(xml: FileLike | ET._ElementTree) -> type[OMEType]:
def _get_root_ome_type(xml: FileLike | AnyElementTree) -> type[OMEType]:
"""Resolve a ome_types.model class for the root element of an OME XML document."""
from ome_types import model

Expand Down
45 changes: 45 additions & 0 deletions src/ome_types/etree_fixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from ome_types._conversion import OME_2016_06_URI

if TYPE_CHECKING:
from ome_types._conversion import AnyElementTree


# Namespace map handed to every ``findall`` call in this module: both the
# default (empty) prefix and the explicit ``ome`` prefix resolve to the
# OME 2016-06 namespace URI.
NSMAP = {
    "": OME_2016_06_URI,
    "ome": OME_2016_06_URI,
}


# See note before using...
def fix_micro_manager_instrument(tree: AnyElementTree) -> AnyElementTree:
    """Fix MicroManager Instrument and Detector IDs and References.

    Some versions of OME-XML produced by MicroManager have invalid IDs (and
    references) for Instruments and Detectors. This function rewrites those IDs
    in place and updates the elements that refer to them.

    NOTE: as of v0.4.0, bad IDs and references are caught during ID validation
    anyway, so this is mostly an example of a fix function, and could be used to
    prevent the warning from being raised.

    Parameters
    ----------
    tree : AnyElementTree
        Parsed OME-XML document. It is modified in place.

    Returns
    -------
    AnyElementTree
        The same ``tree`` object that was passed in.
    """

    def _retarget(tag: str, old: str, new: str) -> None:
        # Point every ``tag`` element whose ID equals ``old`` at ``new``.
        # The comparison is done in Python rather than by interpolating ``old``
        # into an XPath predicate: an ID containing a quote character would
        # otherwise yield an invalid expression.
        for ref in tree.findall(f".//{tag}", NSMAP):
            if ref.get("ID") == old:
                ref.set("ID", new)

    for i_idx, instrument in enumerate(tree.findall("Instrument", NSMAP)):
        instrument_id = instrument.get("ID")
        # An Instrument without an ID attribute is left untouched instead of
        # raising AttributeError on ``None.startswith``.
        if instrument_id is not None and instrument_id.startswith("Microscope"):
            new_id = f"Instrument:{i_idx}"
            instrument.set("ID", new_id)
            _retarget("InstrumentRef", instrument_id, new_id)

        for d_idx, detector in enumerate(instrument.findall(".//Detector", NSMAP)):
            detector_id = detector.get("ID")
            # Skip detectors with no ID, or whose ID already has the expected
            # "Detector:" prefix.
            if detector_id is None or detector_id.startswith("Detector:"):
                continue
            # Keep a purely numeric ID as the suffix; otherwise fall back to
            # the detector's position within this instrument.
            new_id = f"Detector:{detector_id if detector_id.isdigit() else d_idx}"
            detector.set("ID", new_id)
            _retarget("DetectorSettings", detector_id, new_id)

    return tree


# Every fix function defined in this module, collected in one list.
ALL_FIXES = [fix_micro_manager_instrument]
# Names exported by ``from ome_types.etree_fixes import *``.
__all__ = ["ALL_FIXES", "fix_micro_manager_instrument"]
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

DATA = Path(__file__).parent / "data"
ALL_XML = set(DATA.glob("*.ome.xml"))
INVALID = {DATA / "invalid_xml_annotation.ome.xml", DATA / "bad.ome.xml"}
INVALID = {
DATA / "invalid_xml_annotation.ome.xml",
DATA / "bad.ome.xml",
DATA / "MMStack.ome.xml",
}
OLD_SCHEMA = {DATA / "seq0000xy01c1.ome.xml", DATA / "2008_instrument.ome.xml"}
WITH_XML_ANNOTATIONS = {
DATA / "ome_ns.ome.xml",
Expand Down
Loading

0 comments on commit 51b6f2e

Please sign in to comment.