diff --git a/.github/scripts/download_aics_test_data.py b/.github/scripts/download_aics_test_data.py deleted file mode 100644 index 854707db..00000000 --- a/.github/scripts/download_aics_test_data.py +++ /dev/null @@ -1,25 +0,0 @@ -from pathlib import Path - -from quilt3 import Package - - -def download_test_resources() -> None: - root = Path(__file__).parent.parent.parent - resources = (root / "aicsimageio" / "aicsimageio" / "tests" / "resources").resolve() - - # Get the specific hash for test resources - with open(root / "aicsimageio" / "scripts" / "TEST_RESOURCES_HASH.txt") as f: - top_hash = f.readline().strip() - - # Download test resources - resources.mkdir(exist_ok=True) - package = Package.browse( - "aicsimageio/test_resources", - "s3://aics-modeling-packages-test-resources", - top_hash=top_hash, - ) - package["resources"].fetch(resources) - - -if __name__ == "__main__": - download_test_resources() diff --git a/.github/workflows/build.yml b/.github/workflows/built_branch.yml similarity index 58% rename from .github/workflows/build.yml rename to .github/workflows/built_branch.yml index 7a6c3a81..97a8dff5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/built_branch.yml @@ -1,14 +1,9 @@ -name: build branch +name: update built branch on: push: branches: - "main" - schedule: - # - # https://pubs.opengroup.org/onlinepubs/9699919799/utilities/crontab.html#tag_20_25_07 - # Run every Monday at 18:00:00 UTC (Monday at 10:00:00 PST) - - cron: '0 18 * * 1' jobs: build: @@ -17,24 +12,27 @@ jobs: - name: Checkout built branch uses: actions/checkout@v3 with: - ref: 'built' + ref: "built" + fetch-depth: 0 + - uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: "3.x" + - name: Clone main run: git clone https://github.com/tlambert03/ome-types - name: Build run: | - python -m pip install --upgrade pip - pip install build + python -m pip install --upgrade pip build python -m build ome-types --wheel rm -rf ome_types unzip 
ome-types/dist/ome_types-* + - name: Commit if: github.event_name == 'push' run: | git config user.name "Talley Lambert" git config user.email "talley.lambert@gmail.com" git add ome_types - git commit -m 'Commit from GitHub Actions (build.yml)' + git commit -m 'Commit from GitHub Actions (built_branch.yml)' git push --set-upstream origin built diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml deleted file mode 100644 index b2055dc9..00000000 --- a/.github/workflows/release.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: release - -on: - push: - # Sequence of patterns matched against refs/tags - tags: - - 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10 - -jobs: - build: - name: Create Release - runs-on: ubuntu-latest - # if: github.repository == 'tlambert03/ome-types' - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: 3.9 - - - name: Install Dependencies - run: | - python -m pip install --upgrade pip - pip install -U setuptools setuptools_scm wheel check-manifest - pip install -e .[autogen] - - - name: Build Distribution - run: | - TAG="${GITHUB_REF/refs\/tags\/v/}" - echo "tag=${TAG}" >> $GITHUB_ENV - check-manifest - python -m build . - - - name: Create Release - id: create_release - uses: softprops/action-gh-release@v1 - with: - generate_release_notes: true - files: 'dist/*' - - - name: Publish PyPI Package - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.TWINE_API_KEY }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d89e5f1e..bd572a80 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,9 +4,14 @@ on: push: branches: - "main" + - "v2" + tags: + - "v*" # Push events to matching v*, i.e. 
v1.0, v20.15.10 pull_request: - branches: - - "main" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: check_manifest: @@ -28,16 +33,11 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.8', '3.9', '3.10'] + python-version: ["3.8", "3.9", "3.10", "3.11"] platform: [ubuntu-latest, macos-latest, windows-latest] include: - - platform: ubuntu-latest - python-version: '3.7' - - platform: macos-latest - python-version: '3.11' - - platform: ubuntu-latest - python-version: '3.11' - # skipping windows 3.11 until lxml has wheels + - python-version: "3.7" + platform: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -47,10 +47,65 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install tox tox-gh-actions - pip install -e .[autogen] - - name: Test with tox - run: tox - env: - PLATFORM: ${{ matrix.platform }} + python -m pip install -U pip + python -m pip install .[test,dev] + - name: Test + run: pytest --cov --cov-report=xml + + - name: retest without lxml or xmlschema + if: matrix.platform == 'ubuntu-latest' + run: | + pip uninstall -y lxml xmlschema + pytest --cov --cov-report=xml --cov-append + + - uses: codecov/codecov-action@v2 + + test-types: + name: Typesafety + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install -U pip + python -m pip install pytest pytest-mypy-plugins + python -m pip install . 
+ + - name: Test + run: pytest typesafety -v + + deploy: + name: Deploy + runs-on: ubuntu-latest + needs: [test, check_manifest] + if: success() && startsWith(github.ref, 'refs/tags/') && github.event_name != 'schedule' + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + + - name: Build + run: | + pip install -U pip build + python -m build . + + - name: Publish PyPI Package + uses: pypa/gh-action-pypi-publish@master + with: + user: __token__ + password: ${{ secrets.TWINE_API_KEY }} + + - name: Create Release + uses: softprops/action-gh-release@v1 + with: + generate_release_notes: true + files: "dist/*" diff --git a/.github/workflows/test_aics.yml b/.github/workflows/test_aics.yml deleted file mode 100644 index 5fb5fc1b..00000000 --- a/.github/workflows/test_aics.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: test aicsimageio - -on: - push: - branches: - - "main" - workflow_dispatch: - -jobs: - test: - name: test aicsimageio - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.9 - uses: actions/setup-python@v4 - with: - python-version: 3.9 - - - name: Install ome-types - run: | - python -m pip install --upgrade pip - pip install wheel - pip install .[test] - - - name: Clone aicsimageio - run: | - git clone --recurse-submodules -b main https://github.com/AllenCellModeling/aicsimageio.git - - - uses: actions/cache@v3 - id: cache - with: - path: aicsimageio/aicsimageio/tests/resources - key: ${{ hashFiles('aicsimageio/scripts/TEST_RESOURCES_HASH.txt') }} - - - name: Install Test Data - if: steps.cache.outputs.cache-hit != 'true' - run: | - pip install quilt3 - python .github/scripts/download_aics_test_data.py - - - name: Install aicsimageio - run: | - cd aicsimageio - pip install .[test] - - - name: Run Tests - run: | - cd aicsimageio - pytest aicsimageio/tests/readers/test_ome_tiff_reader.py -v -k "not REMOTE" diff --git 
a/.github/workflows/test_dependents.yml b/.github/workflows/test_dependents.yml new file mode 100644 index 00000000..fdd85653 --- /dev/null +++ b/.github/workflows/test_dependents.yml @@ -0,0 +1,154 @@ +name: test dependents + +on: + push: + branches: + - "main" + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test-aicsimageio: + name: test aicsimageio + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + repository: AllenCellModeling/aicsimageio + submodules: true + - uses: actions/checkout@v3 + with: + path: ome-types + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Install aicsimageio + run: | + python -m pip install --upgrade pip + python -m pip install .[test] + python -m pip install bioformats_jar + + - uses: actions/cache@v3 + id: cache + with: + path: aicsimageio/tests/resources + key: ${{ hashFiles('scripts/TEST_RESOURCES_HASH.txt') }} + + - name: Download Test Resources + if: steps.cache.outputs.cache-hit != 'true' + run: python scripts/download_test_resources.py --debug + + - name: Install ome-types + run: pip install . 
+ working-directory: ome-types + + - name: Run Tests + run: | + pytest --color=yes -k "not test_known_errors_without_cleaning and not bad" \ + aicsimageio/tests/readers/test_ome_tiff_reader.py \ + aicsimageio/tests/writers/test_ome_tiff_writer.py \ + aicsimageio/tests/readers/extra_readers/test_bioformats_reader.py \ + aicsimageio/tests/readers/extra_readers/test_ome_zarr_reader.py + + test-paquo: + name: test paquo + runs-on: ubuntu-latest + env: + QUPATH_VERSION: 0.4.3 + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install -U pip + python -m pip install .[test,dev] + python -m pip install paquo + + - name: Restore qupath cache + uses: actions/cache@v3 + env: + CACHE_NUMBER: 0 + with: + path: ./qupath/download + key: ${{ runner.os }}-qupath-v${{ env.CACHE_NUMBER }} + + - name: Install qupath and set PAQUO_QUPATH_DIR + shell: bash + run: | + python -c "import os; os.makedirs('qupath/download', exist_ok=True)" + python -c "import os; os.makedirs('qupath/apps', exist_ok=True)" + python -m paquo get_qupath --install-path ./qupath/apps --download-path ./qupath/download ${{ env.QUPATH_VERSION }} \ + | grep -v "^#" | sed "s/^/PAQUO_QUPATH_DIR=/" >> $GITHUB_ENV + + - name: Test with pytest + run: pytest tests/test_paquo.py + + test-nd2: + name: test nd2 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + repository: tlambert03/nd2 + fetch-depth: 0 + - uses: actions/checkout@v3 + with: + path: ome-types + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Install nd2 + run: | + python -m pip install --upgrade pip + python -m pip install .[test] + + - uses: actions/cache@v3 + id: cache + with: + path: tests/data + key: ${{ hashFiles('scripts/download_samples.py') }} + + - name: Download Samples + if: steps.cache.outputs.cache-hit != 'true' + run: | + 
pip install requests + python scripts/download_samples.py + + - name: Install ome-types + run: pip install . + working-directory: ome-types + + - name: Run Tests + run: pytest --color=yes -v tests/test_ome.py + + test_omero_cli: + name: test omero-cli-transfer + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install -U pip + python -m pip install .[test,dev] + python -m pip install omero-cli-transfer --no-deps + + - name: Test + run: pytest tests/test_omero_cli.py -v diff --git a/.gitignore b/.gitignore index 0fb313fa..bc9f30dd 100644 --- a/.gitignore +++ b/.gitignore @@ -107,8 +107,9 @@ venv.bak/ .DS_Store _test_data -src/ome_types/model/ +src/ome_types/_autogenerated/ src/ome_types/_version.py docs/source/_autosummary .benchmarks/ _build/ +qupath/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 433bdff8..3bb9f58a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,3 +29,7 @@ repos: hooks: - id: mypy exclude: ^tests|^docs|_napari_plugin|widgets + additional_dependencies: + - pydantic<2 + - Pint + - types-lxml; python_version > '3.8' diff --git a/README.md b/README.md index d3e09562..de416754 100644 --- a/README.md +++ b/README.md @@ -304,6 +304,3 @@ python src/ome_autogen.py To run tests quickly, just install and run `pytest`. Note, however, that this requires that the `ome_types.model` module has already been built with `python src/ome_autogen.py`. - -Alternatively, you can install and run `tox` which will run tests and -code-quality checks in an isolated environment. 
diff --git a/docs/source/ome_types.model.rst b/docs/source/ome_types.model.rst index 35b106e5..b792bcfb 100644 --- a/docs/source/ome_types.model.rst +++ b/docs/source/ome_types.model.rst @@ -56,7 +56,6 @@ ome\_types.model LightEmittingDiode LightPath LightSource - LightSourceGroup LightSourceSettings Line ListAnnotation @@ -92,7 +91,6 @@ ome\_types.model Screen Settings Shape - ShapeGroup StageLabel StructuredAnnotations TagAnnotation diff --git a/hatch_build.py b/hatch_build.py index 122c836b..b1ffb95d 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -1,3 +1,6 @@ +import os +import sys + from hatchling.builders.hooks.plugin.interface import BuildHookInterface @@ -8,10 +11,11 @@ class CustomBuildHook(BuildHookInterface): def initialize(self, version: str, build_data: dict) -> None: """Init before the build process begins.""" - import sys + if os.getenv("SKIP_AUTOGEN"): + return sys.path.append("src") - import ome_autogen + import ome_autogen.main - ome_autogen.convert_schema() + ome_autogen.main.build_model(do_mypy=False) diff --git a/pyproject.toml b/pyproject.toml index 38ce00b3..c632090b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,26 +3,6 @@ requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" -# https://hatch.pypa.io/latest/config/metadata/ -[tool.hatch.version] -source = "vcs" - -[tool.hatch.build] -artifacts = ["src/ome_types/model"] - -[tool.hatch.build.targets.sdist] -include = ["src", "tests", "CHANGELOG.md"] - -[tool.hatch.build.hooks.custom] -# requirements to run the autogen script in hatch_build.py -dependencies = [ - "black", - "isort>=5.0", - "xmlschema==1.4.1", - "autoflake", - "numpydoc", -] - # https://peps.python.org/pep-0621/ [project] name = "ome-types" @@ -51,8 +31,9 @@ dynamic = ["version"] dependencies = [ "Pint >=0.15", "lxml >=4.8.0", - "pydantic[email] >=1.0, <2.0", - "xmlschema >=2.0.0", + "pydantic[email] <2.0", + "xsdata", + "importlib_metadata; python_version < '3.8'", ] [project.urls] @@ -66,53 +47,84 
@@ ome-types = "ome_types._napari_plugin" # extras # https://peps.python.org/pep-0621/#dependencies-optional-dependencies [project.optional-dependencies] -autogen = ["autoflake", "black", "isort>=5.0", "numpydoc"] -docs = [ - "autoflake", +dev = [ "black", - "ipython", - "isort>=5.0", + "ruff", + "xsdata[cli]>=23.6", + "mypy", + "pre-commit", + "types-lxml; python_version >= '3.8'", +] +docs = [ "numpydoc", "pygments", "sphinx==5.3.0", "sphinx-rtd-theme==1.1.1", + "ipython", ] -test = ["pytest", "pytest-cov", "pytest-benchmark", "tox"] +test = ["pytest", "pytest-cov", "xmlschema", "pytest-mypy-plugins"] + +# https://hatch.pypa.io/latest/plugins/build-hook/custom/ +[tool.hatch.build.targets.wheel.hooks.custom] +# requirements to run the autogen script in hatch_build.py +require-runtime-dependencies = true +dependencies = ["black", "ruff", "xsdata[cli]>=23.6"] + +# https://hatch.pypa.io/latest/config/metadata/ +[tool.hatch] +version = { source = "vcs" } + +[tool.hatch.build.targets.wheel] +only-include = ["src/ome_types", "src/xsdata_pydantic_basemodel"] +sources = ["src"] +artifacts = ["src/ome_types/_autogenerated"] + +[tool.hatch.build.targets.sdist] +include = ["src", "CHANGELOG.md", "hatch_build.py"] +exclude = ["src/ome_types/_autogenerated"] -[tool.setuptools_scm] -write_to = "src/ome_types/_version.py" # https://github.com/charliermarsh/ruff [tool.ruff] line-length = 88 src = ["src", "tests"] target-version = "py37" -extend-select = [ +select = [ "E", # style errors "F", # flakes "D", # pydocstyle - "I001", # isort + "I", # isort "UP", # pyupgrade - # "N", # pep8-naming "S", # bandit - "C", # flake8-comprehensions + "C4", # flake8-comprehensions "B", # flake8-bugbear "A001", # flake8-builtins + "TID", # tidy + "TCH", # typechecking "RUF", # ruff-specific rules ] -extend-ignore = [ - "D100", # Missing docstring in public module - "D104", # Missing docstring in public package - "D107", # Missing docstring in __init__ - "D203", # 1 blank line required before 
class docstring - "D212", # Multi-line docstring summary should start at the first line - "D213", # Multi-line docstring summary should start at the second line - "D400", # First line should end with a period - "D413", # Missing blank line after last section - "D416", # Section name should end with a colon - "C901", # Function is too complex +ignore = [ + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D104", # Missing docstring in public package + "D106", # Missing docstring in public nested class + "D107", # Missing docstring in __init__ + "D203", # 1 blank line required before class docstring + "D205", # 1 blank line required between summary line and description + "D212", # Multi-line docstring summary should start at the first line + "D213", # Multi-line docstring summary should start at the second line + "D400", # First line should end with a period + "D404", # First word of the docstring should not be This + "D413", # Missing blank line after last section + "D416", # Section name should end with a colon + "C901", # Function is too complex + "RUF009", # Do not perform function calls in default arguments ] +exclude = ['src/_ome_autogen.py'] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" # Disallow all relative imports. 
[tool.ruff.per-file-ignores] "tests/*.py" = ["D", "S"] @@ -120,17 +132,17 @@ extend-ignore = [ ".github/*.py" = ["D"] "setup.py" = ["D"] "docs/**/*.py" = ["D"] +"src/xsdata_pydantic_basemodel/**/*.py" = ["D"] [tool.check-manifest] ignore = [ - "src/ome_types/model/*", "coverage.yml", ".pre-commit-config.yaml", ".github_changelog_generator", ".readthedocs.yml", "docs/**/*", - "setup.py", - "tox.ini", + "tests/**/*", + "typesafety/**/*", ] @@ -141,30 +153,21 @@ float_to_top = true skip_glob = ["*examples/*", "*vendored*"] [tool.black] -target-version = ['py37', 'py38'] -exclude = ''' -/( - \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist - | docs -)/ -''' +target-version = ['py38'] # https://docs.pytest.org/en/6.2.x/customize.html [tool.pytest.ini_options] minversion = "6.0" -addopts = "--benchmark-disable" testpaths = ["tests"] +addopts = '--mypy-only-local-stub --color=yes' filterwarnings = [ "error", "ignore:Casting invalid AnnotationID:UserWarning", + # FIXME: i think this might be an xsdata issue? 
+ "ignore::ResourceWarning", + "ignore:pkg_resources is deprecated", # paquo tests + "ignore:::paquo", # paquo tests + "ignore:the imp module is deprecated", # omero tests ] # https://mypy.readthedocs.io/en/stable/config_file.html @@ -173,7 +176,41 @@ files = "src/**/*/*.py" follow_imports = 'silent' strict_optional = true warn_redundant_casts = true +# warn_unused_ignores = true disallow_any_generics = false no_implicit_reexport = true ignore_missing_imports = true disallow_untyped_defs = true +plugins = "pydantic.mypy" + +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = false # allow parsing Any + +[[tool.mypy.overrides]] +module = ['ome_types._autogenerated.ome_2016_06.structured_annotations'] +# Definition of "__iter__" in base class "BaseModel" +# is incompatible with definition in base class "Sequence" +disable_error_code = "misc" + +# https://coverage.readthedocs.io/en/6.4/config.html +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "@overload", + "except ImportError", + "\\.\\.\\.", + "raise NotImplementedError()", +] + +[tool.coverage.run] +source = ["src/ome_types", "src/ome_autogen"] + + +# Entry points -- REMOVE ONCE XSDATA-PYDANTIC-BASEMODEL IS SEPARATE +[project.entry-points."xsdata.plugins.class_types"] +xsdata_pydantic_basemodel = "xsdata_pydantic_basemodel.hooks.class_type" + +[project.entry-points."xsdata.plugins.cli"] +xsdata_pydantic_basemodel = "xsdata_pydantic_basemodel.hooks.cli" diff --git a/setup.py b/setup.py deleted file mode 100644 index b3742253..00000000 --- a/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -import sys - -sys.stderr.write( - """ -=============================== -Unsupported installation method -=============================== -ome-types does not support installation with `python setup.py install`. -Please use `python -m pip install .` instead. 
-""" -) -sys.exit(1) - - -# The below code will never execute, however GitHub is particularly -# picky about where it finds Python packaging metadata. -# See: https://github.com/github/feedback/discussions/6456 -# -# To be removed once GitHub catches up. - -setup( # type: ignore # noqa - name="ome-types", - install_requires=[ - "Pint >=0.15", - "lxml >=4.8.0", - "pydantic[email] >=1.0", - "xmlschema >=1.4.1", - ], -) diff --git a/src/ome_autogen.py b/src/ome_autogen.py deleted file mode 100644 index 420f2a5c..00000000 --- a/src/ome_autogen.py +++ /dev/null @@ -1,1119 +0,0 @@ -from __future__ import annotations - -import builtins -import collections -import os -import re -import shutil -from dataclasses import dataclass, field -from itertools import chain -from pathlib import Path -from textwrap import dedent, indent, wrap -from typing import Any, DefaultDict, Generator, Iterable, Iterator - -import black -import isort.api -from autoflake import fix_code -from numpydoc.docscrape import NumpyDocString, Parameter -from xmlschema import XMLSchema -from xmlschema.validators import ( - XsdAnyAttribute, - XsdAnyElement, - XsdAtomicBuiltin, - XsdAttribute, - XsdComponent, - XsdElement, - XsdType, -) - -try: - # xmlschema ≥ v1.4.0 - from xmlschema import names as qnames -except ImportError: - from xmlschema import qnames - - -# Track all camel-to-snake and pluralization results so we can include them in the model -camel_snake_registry: dict[str, str] = {} -plural_registry: dict[tuple[str, str], str] = {} -LISTS: collections.defaultdict[str, set[str]] = DefaultDict(set) - -# FIXME: Work out a better way to implement these override hacks. 
- - -@dataclass -class Override: - type_: str - default: str | None = None - imports: str | None = None - body: str | None = None - - def __post_init__(self) -> None: - if self.imports: - self.imports = dedent(self.imports) - if self.body: - self.body = indent(dedent(self.body), " " * 4) - - -@dataclass -class ClassOverride: - base_type: str | None = None - imports: str | None = None - fields: str | None = None - fields_suppress: set[str] = field(default_factory=set) - body: str | None = None - - def __post_init__(self) -> None: - if self.imports: - self.imports = dedent(self.imports) - if self.fields: - self.fields = indent(dedent(self.fields), " " * 4) - if self.body: - self.body = indent(dedent(self.body), " " * 4) - - -# Maps XSD TypeName to Override configuration, used to control output for that type. -OVERRIDES = { - "MetadataOnly": Override(type_="bool", default="False"), - # FIXME: Type should be xml.etree.ElementTree.Element but isinstance checks - # with that class often mysteriously fail so the validator fails. 
- "XMLAnnotation/Value": Override(type_="Element"), - "BinData/Length": Override(type_="int"), - # FIXME: hard-coded subclass lists - "Instrument/LightSourceGroup": Override( - type_="List[LightSourceGroupType]", - default="Field(default_factory=list)", - imports=""" - from typing import Dict, Union, Any - from pydantic import root_validator - from .light_source_group import LightSourceGroupType - """, - body=""" - @root_validator(pre=True) - def _root(cls, value: Dict[str, Any]): - light_sources = {i.snake_name() for i in LightSourceGroupType.__args__} # type: ignore - lights = [] - for key in list(value): - kind = {"kind": key} - if key in light_sources: - val = value.pop(key) - if isinstance(val, dict): - lights.append({**val, **kind}) - elif isinstance(val, list): - lights.extend({**v, **kind} for v in val) - if lights: - value.setdefault("light_source_group", []) - value["light_source_group"].extend(lights) - return value - """, - ), - "ROI/Union": Override( - type_="List[ShapeGroupType]", - default="Field(default_factory=list)", - imports=""" - from typing import Dict, Union, Any, Sequence, Iterator - from pydantic import validator - - from .annotation_ref import AnnotationRef - from .shape_group import ShapeGroupType - from .simple_types import ROIID - """, - body=""" - @validator("union", pre=True) - def _validate_union(cls, value: Any) -> Sequence[Dict[str, Any]]: - if isinstance(value, dict): - return list(cls._flatten_union_dict(value)) - if not isinstance(value, Sequence): - raise TypeError("must be dict or sequence of dicts") - return value - - @classmethod - def _flatten_union_dict(cls, nested: Dict[str, Any], keyname: str = "kind" - ) -> Iterator[Dict[str, Any]]: - for key, value in nested.items(): - keydict = {keyname: key} if keyname else {} - if isinstance(value, list): - yield from ({**x, **keydict} for x in value) - else: - yield {**value, **keydict} - """, - ), - "OME/StructuredAnnotations": Override( - type_="List[Annotation]", - 
default="Field(default_factory=list)", - imports=""" - from typing import Dict, Union, Any - from pydantic import validator - from .annotation import Annotation - from .boolean_annotation import BooleanAnnotation - from .comment_annotation import CommentAnnotation - from .double_annotation import DoubleAnnotation - from .file_annotation import FileAnnotation - from .list_annotation import ListAnnotation - from .long_annotation import LongAnnotation - from .map_annotation import MapAnnotation - from .tag_annotation import TagAnnotation - from .term_annotation import TermAnnotation - from .timestamp_annotation import TimestampAnnotation - from .xml_annotation import XMLAnnotation - - _annotation_types: Dict[str, type] = { - "boolean_annotation": BooleanAnnotation, - "comment_annotation": CommentAnnotation, - "double_annotation": DoubleAnnotation, - "file_annotation": FileAnnotation, - "list_annotation": ListAnnotation, - "long_annotation": LongAnnotation, - "map_annotation": MapAnnotation, - "tag_annotation": TagAnnotation, - "term_annotation": TermAnnotation, - "timestamp_annotation": TimestampAnnotation, - "xml_annotation": XMLAnnotation, - } - """, - body=""" - @validator("structured_annotations", pre=True, each_item=True) - def validate_structured_annotations( - cls, value: Union[Annotation, Dict[Any, Any]] - ) -> Annotation: - if isinstance(value, Annotation): - return value - elif isinstance(value, dict): - try: - _type = value.pop("_type") - except KeyError: - raise ValueError( - "dict initialization requires _type" - ) from None - try: - annotation_cls = _annotation_types[_type] - except KeyError: - raise ValueError(f"unknown Annotation type '{_type}'") from None - return annotation_cls(**value) - else: - raise ValueError("invalid type for annotation values") - """, - ), - "TiffData/UUID": Override( - type_="Optional[UUID]", - default="None", - imports=""" - from typing import Optional - from .simple_types import UniversallyUniqueIdentifier - from 
ome_types._base_type import OMEType - - class UUID(OMEType): - file_name: str - value: UniversallyUniqueIdentifier - """, - ), - "M/K": Override(type_="str", default=""), -} - - -# Maps XSD TypeName to ClassOverride configuration, used to control dataclass -# generation. -CLASS_OVERRIDES = { - "OME": ClassOverride( - imports=""" - from typing import Any - import weakref - from ome_types import util - from pathlib import Path - """, - body=""" - def __init__(self, **data: Any) -> None: - super().__init__(**data) - self._link_refs() - - def _link_refs(self) -> None: - ids = util.collect_ids(self) - for ref in util.collect_references(self): - ref._ref = weakref.ref(ids[ref.id]) - - def __setstate__(self: Any, state: Dict[str, Any]) -> None: - '''Support unpickle of our weakref references.''' - super().__setstate__(state) - self._link_refs() - - @classmethod - def from_xml(cls, xml: Union[Path, str]) -> 'OME': - from ome_types import from_xml - return from_xml(xml) - - @classmethod - def from_tiff(cls, path: Union[Path, str]) -> 'OME': - from ome_types import from_tiff - return from_tiff(path) - - def to_xml(self) -> str: - from ome_types import to_xml - return to_xml(self) - """, - ), - "Reference": ClassOverride( - imports=""" - from pydantic import Field - from typing import Any, Optional, TYPE_CHECKING - from weakref import ReferenceType - from .simple_types import LSID - """, - fields=""" - if TYPE_CHECKING: - _ref: Optional["ReferenceType[OMEType]"] - - id: LSID - _ref = None - """, - # FIXME Could make `ref` abstract and implement stronger-typed overrides - # in subclasses. 
- body=""" - @property - def ref(self) -> Any: - if self._ref is None: - raise ValueError("references not yet resolved on root OME object") - return self._ref() - """, - ), - "XMLAnnotation": ClassOverride( - imports=""" - from xml.etree import ElementTree - from typing import Generator - from typing import Callable, Generator, Any, Dict - - class Element(ElementTree.Element): - '''ElementTree.Element that supports pydantic validation.''' - @classmethod - def __get_validators__(cls) -> Generator[Callable[[Any], Any], None, None]: - yield cls.validate - - @classmethod - def validate(cls, v: Any) -> ElementTree.Element: - if isinstance(v, ElementTree.Element): - return v - try: - return ElementTree.fromstring(v) - except ElementTree.ParseError as e: - raise ValueError(f"Invalid XML string: {e}") - """, - body=""" - # NOTE: pickling this object requires xmlschema>=1.4.1 - - def dict(self, **k: Any) -> Dict[str, Any]: - d = super().dict(**k) - d["value"] = ElementTree.tostring( - d.pop("value"), encoding="unicode", method="xml" - ).strip() - return d - """, - ), - "BinData": ClassOverride(base_type="object", fields="value: str"), - "Map": ClassOverride(fields_suppress={"K"}), - "M": ClassOverride(base_type="object", fields="value: str"), - "LightEmittingDiode": ClassOverride( - fields='kind: Literal["light_emitting_diode"] = "light_emitting_diode"', - imports="from typing_extensions import Literal", - ), - "Laser": ClassOverride( - fields='kind: Literal["laser"] = "laser"', - imports="from typing_extensions import Literal", - ), - "Arc": ClassOverride( - fields='kind: Literal["arc"] = "arc"', - imports="from typing_extensions import Literal", - ), - "Filament": ClassOverride( - fields='kind: Literal["filament"] = "filament"', - imports="from typing_extensions import Literal", - ), - "GenericExcitationSource": ClassOverride( - fields='kind: Literal["generic_excitation_source"] = "generic_excitation_source"', - imports="from typing_extensions import Literal", - ), - 
"Label": ClassOverride( - fields='kind: Literal["label"] = "label"', - imports="from typing_extensions import Literal", - ), - "Point": ClassOverride( - fields='kind: Literal["point"] = "point"', - imports="from typing_extensions import Literal", - ), - "Mask": ClassOverride( - fields='kind: Literal["mask"] = "mask"', - imports="from typing_extensions import Literal", - ), - "Rectangle": ClassOverride( - fields='kind: Literal["rectangle"] = "rectangle"', - imports="from typing_extensions import Literal", - ), - "Polygon": ClassOverride( - fields='kind: Literal["polygon"] = "polygon"', - imports="from typing_extensions import Literal", - ), - "Polyline": ClassOverride( - fields='kind: Literal["polyline"] = "polyline"', - imports="from typing_extensions import Literal", - ), - "Line": ClassOverride( - fields='kind: Literal["line"] = "line"', - imports="from typing_extensions import Literal", - ), - "Ellipse": ClassOverride( - fields='kind: Literal["ellipse"] = "ellipse"', - imports="from typing_extensions import Literal", - ), -} - - -def autoflake(text: str, **kwargs: Any) -> str: - kwargs.setdefault("remove_all_unused_imports", True) - kwargs.setdefault("remove_unused_variables", True) - return fix_code(text, **kwargs) - - -def black_format(text: str, line_length: int = 88) -> str: - return black.format_str(text, mode=black.FileMode(line_length=line_length)) - - -def sort_imports(text: str) -> str: - return isort.api.sort_code_string(text, profile="black", float_to_top=True) - - -def sort_types(el: XsdType) -> str: - if not el.is_complex() and not el.base_type.is_restriction(): - return " " + el.local_name.lower() - return el.local_name.lower() - - -def sort_prop(prop: Member) -> str: - return ("" if prop.default_val_str else " ") + prop.format().lower() - - -def as_identifier(s: str) -> str: - # Remove invalid characters - _s = re.sub("[^0-9a-zA-Z_]", "", s) - # Remove leading characters until we find a letter or underscore - _s = re.sub("^[^a-zA-Z_]+", "", _s) - 
if not _s: - raise ValueError(f"Could not clean {s}: nothing left") - return _s - - -CAMEL_SNAKE_OVERRIDES = {"ROIs": "rois"} - - -def camel_to_snake(name: str) -> str: - result = CAMEL_SNAKE_OVERRIDES.get(name, None) - if not result: - # https://stackoverflow.com/a/1176023 - result = re.sub("([A-Z]+)([A-Z][a-z]+)", r"\1_\2", name) - result = re.sub("([a-z0-9])([A-Z])", r"\1_\2", result) - result = result.lower().replace(" ", "_") - camel_snake_registry[name] = result - return result - - -def local_import(item_type: str) -> str: - return f"from .{camel_to_snake(item_type)} import {item_type}" - - -def get_docstring(component: XsdComponent | XsdType, summary: bool = False) -> str: - try: - doc = dedent(component.annotation.documentation[0].text).strip() - # some docstrings include a start ('*Word') which makes sphinx angry - # this line wraps those in backticks - doc = re.sub(r"(\*\w+)\s", r"`\1` ", doc) - # make sure the first line is followed by a double newline - # and preserve paragraphs - if summary: - doc = re.sub(r"\.\s", ".\n\n", doc, count=1) - # textwrap each paragraph seperately - paragraphs = ["\n".join(wrap(p.strip(), width=78)) for p in doc.split("\n\n")] - # join and return - return "\n\n".join(paragraphs) - except (AttributeError, IndexError): - return "" - - -def make_dataclass(component: XsdComponent | XsdType) -> list[str]: - class_override = CLASS_OVERRIDES.get(component.local_name, None) - lines = ["from ome_types._base_type import OMEType", ""] - if isinstance(component, XsdType): - base_type = component.base_type - else: - base_type = component.type.base_type - - if class_override and class_override.base_type: - if class_override.base_type == "object": - base_name = "(OMEType)" - else: - base_name = f"({class_override.base_type}, OMEType)" - base_type = None - elif base_type and not hasattr(base_type, "python_type"): - base_name = f"({base_type.local_name}, OMEType)" - if base_type.is_complex(): - lines += [local_import(base_type.local_name)] 
- else: - lines += [f"from .simple_types import {base_type.local_name}"] - else: - base_name = "(OMEType)" - if class_override and class_override.imports: - lines.append(class_override.imports) - - base_members = set() - _basebase = base_type - while _basebase: - base_members.update(set(iter_members(base_type))) - _basebase = _basebase.base_type - skip_names = set() - if class_override: - skip_names.update(class_override.fields_suppress) - - members = MemberSet( - m - for m in iter_members(component) - if m not in base_members and m.local_name not in skip_names - ) - lines += members.imports() - lines += members.locals() - - lines += [f"class {component.local_name}{base_name}:"] - doc = get_docstring(component, summary=True) - doc = MemberSet(iter_members(component)).docstring( - doc or f"{component.local_name}." - ) - doc = f'"""{doc.strip()}\n"""\n' - lines += indent(doc, " ").splitlines() - if class_override and class_override.fields: - lines.append(class_override.fields) - lines += members.lines(indent=1) - - if class_override and class_override.body: - lines.append(class_override.body) - lines += members.body() - - return lines - - -def make_abstract_class(component: XsdComponent) -> list[str]: - # FIXME: ? 
this might be a bit of an OME-schema-specific hack - # this seems to be how abstract is used in the OME schema - for e in component.iter_components(): - if e != component: - raise NotImplementedError( - "Don't yet know how to handle abstract class with sub-components" - ) - - subs = [ - el - for el in component.schema.elements.values() - if el.substitution_group == component.name - ] - - if not subs: - raise NotImplementedError( - "Don't know how to handle abstract class without substitutionGroups" - ) - - for el in subs: - if not el.type.is_extension() and el.type.base_type == component.type: - raise NotImplementedError( - "Expected all items in substitution group to extend " - f"the type {component.type} of Abstract element {component}" - ) - - sub_names = [el.local_name for el in subs] - lines = ["from typing import Union", *[local_import(n) for n in sub_names]] - lines += [local_import(component.type.local_name)] - lines += [f"{component.local_name} = {component.type.local_name}", ""] - lines += [f"{component.local_name}Type = Union[{', '.join(sub_names)}]"] - return lines - - -def make_enum(component: XsdComponent) -> list[str]: - name = component.local_name - _type = component.type if hasattr(component, "type") else component - if _type.is_list(): - _type = _type.item_type - lines = ["from enum import Enum", ""] - lines += [f"class {name}(Enum):"] - doc = get_docstring(component, summary=True) - if doc: - if not doc.endswith("."): - doc += "." 
- doc = f'"""{doc}\n"""\n' - lines += indent(doc, " ").splitlines() - enum_elems = list(_type.elem.iter("enum")) - facets = _type.get_facet(qnames.XSD_ENUMERATION) - members: list[tuple[str, str]] = [] - if enum_elems: - for el, value in zip(enum_elems, facets.enumeration): - _name = el.attrib["enum"] - if _type.base_type.python_type.__name__ == "str": - value = f'"{value}"' - members.append((_name, value)) - else: - for e in facets.enumeration: - members.append((camel_to_snake(e), repr(e))) - - for n, v in sorted(members): - lines.append(f" {as_identifier(n).upper()} = {v}") - return lines - - -def make_color() -> list[str]: - color = """ - from pydantic import color - class Color(color.Color): - def __init__(self, val: color.ColorType) -> None: - if isinstance(val, int): - val = self._int2tuple(val) - super().__init__(val) - - @classmethod - def _int2tuple(cls, val: int): - return (val >> 24 & 255, val >> 16 & 255, val >> 8 & 255, (val & 255) / 255) - - def as_int32(self) -> int: - r, g, b, *a = self.as_rgb_tuple() - v = r << 24 | g << 16 | b << 8 | int((a[0] if a else 1) * 255) - if v < 2 ** 32 // 2: - return v - return v - 2 ** 32 - - def __eq__(self, o: object) -> bool: - if isinstance(o, Color): - return self.as_int32() == o.as_int32() - return False - - def __int__(self) -> int: - return self.as_int32() - """ - return dedent(color).strip().splitlines() - - -facet_converters = { - qnames.XSD_PATTERN: lambda f: [f"regex = re.compile(r'{f.regexps[0]}')"], - qnames.XSD_MIN_INCLUSIVE: lambda f: [f"ge = {f.value}"], - qnames.XSD_MIN_EXCLUSIVE: lambda f: [f"gt = {f.value}"], - qnames.XSD_MAX_INCLUSIVE: lambda f: [f"le = {f.value}"], - qnames.XSD_MAX_EXCLUSIVE: lambda f: [f"lt = {f.value}"], - qnames.XSD_LENGTH: lambda f: [f"min_length = {f.value}", f"max_length = {f.value}"], - qnames.XSD_MIN_LENGTH: lambda f: [f"min_length = {f.value}"], - qnames.XSD_MAX_LENGTH: lambda f: [f"max_length = {f.value}"], -} - - -def iter_all_members( - component: XsdComponent, -) -> 
Generator[XsdElement | XsdAttribute, None, None]: - for c in component.iter_components((XsdElement, XsdAttribute)): - if c is component: - continue - yield c - - -def iter_members( - component: XsdElement | XsdType, -) -> Generator[XsdElement | XsdAttribute, None, None]: - if isinstance(component, XsdElement): - for attr in component.attributes.values(): - if isinstance(attr, XsdAttribute): - yield attr - yield from component.iterchildren() - else: - yield from iter_all_members(component) - - -def is_enum_type(obj: XsdType) -> bool: - """Return true if XsdType represents an enumeration.""" - return obj.get_facet(qnames.XSD_ENUMERATION) is not None - - -class Member: - def __init__(self, component: XsdElement | XsdAttribute): - self.component = component - if component.is_global(): - raise ValueError(f"global component {component!r} not allowed") - - @property - def identifier(self) -> str: - if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)): - return self.component.local_name - name = camel_to_snake(self.component.local_name) - if self.plural: - plural = camel_to_snake(self.plural) - plural_registry[(self.parent_name, name)] = plural - name = plural - if not name.isidentifier(): - raise ValueError(f"failed to make identifier of {self!r}") - return name - - @property - def plural(self) -> str | None: - """Plural form of component name, if available.""" - if ( - isinstance(self.component, XsdElement) - and self.component.is_multiple() - and self.component.ref - and self.component.ref.annotation - ): - appinfo = self.component.ref.annotation.appinfo - if len(appinfo) != 1: - raise ValueError("unexpected multiple appinfo elements") - plural = appinfo[0].find("xsdfu/plural") - if plural is not None: - return plural.text - return None - - @property - def type(self) -> XsdType: - return self.component.type - - def to_numpydoc_param(self) -> Parameter: - _type = self.type_string - _type += ", optional" if self.is_optional else "" - desc = 
get_docstring(self.component) - desc = re.sub(r"\s?\[.+\]", "", desc) # remove bracketed types - return Parameter(self.identifier, _type, wrap(desc)) - - @property - def is_builtin_type(self) -> bool: - return isinstance(self.type, XsdAtomicBuiltin) - - @property - def is_decimal(self) -> bool: - return self.component.type.is_derived( - self.component.schema.builtin_types()["decimal"] - ) - - @property - def is_nonref_id(self) -> bool: - """Return True for 'id' fields that aren't part of a Reference type.""" - if self.identifier != "id": - return False - # Walk up the containment tree until we find something with a base_type. - p = self.component.parent - while p is not None and not hasattr(p, "base_type"): - p = p.parent - if p is not None: - # Walk the type hierarchy looking for 'Reference'. - pt = p.base_type - while pt is not None: - if pt.local_name == "Reference": - return False - pt = pt.base_type - # If we get here, we have an 'id' that isn't in a Reference type. - return True - - @property - def is_ref_id(self) -> bool: - if self.identifier == "id": - return not self.is_nonref_id - return False - - @property - def parent_name(self) -> str: - """Local name of component's first named ancestor.""" - p = self.component.parent - while not p.local_name and p.parent is not None: - p = p.parent - return p.local_name - - @property - def key(self) -> str: - name = f"{self.parent_name}/{self.component.local_name}" - if name not in OVERRIDES and self.component.local_name in OVERRIDES: - return self.component.local_name - return name - - def locals(self) -> list[str]: - if self.key in OVERRIDES: - return [] - if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)): - return [] - if not self.type or self.type.is_global(): - return [] - locals_: list[str] = [] - # FIXME: this bit is mostly hacks - if self.type.is_complex() and self.component.ref is None: - locals_.append("\n".join(make_dataclass(self.component)) + "\n") - if self.type.is_restriction() and 
is_enum_type(self.type): - locals_.append("\n".join(make_enum(self.component)) + "\n") - if self.type.is_list() and is_enum_type(self.type.item_type): - locals_.append("\n".join(make_enum(self.component)) + "\n") - return locals_ - - def imports(self) -> list[str]: - if self.key in OVERRIDES: - _imp = OVERRIDES[self.key].imports - return [_imp] if _imp else [] - if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)): - return ["from typing import Any"] - imports = [] - if not self.max_occurs: - imports.append("from typing import List") - if self.is_optional: - imports.append("from pydantic import Field") - elif self.is_optional: - imports.append("from typing import Optional") - if self.is_decimal: - imports.append("from typing import cast") - if self.type.is_datetime(): - imports.append("from datetime import datetime") - if not self.is_builtin_type and self.type.is_global(): - # FIXME: hack - if not self.type.local_name == "anyType": - if self.type.is_complex(): - imports.append(local_import(self.type.local_name)) - else: - imports.append(f"from .simple_types import {self.type.local_name}") - - if self.component.ref is not None: - if self.component.ref.local_name not in OVERRIDES: - imports.append(local_import(self.component.ref.local_name)) - - return imports - - def body(self) -> str: - if self.key in OVERRIDES: - return OVERRIDES[self.key].body or "" - return "" - - @property - def type_string(self) -> str: - """Single type, without Optional, etc...""" - if self.key in OVERRIDES: - return OVERRIDES[self.key].type_ - if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)): - return "Any" - if self.component.ref is not None: - if not self.component.ref.is_global(): - raise ValueError("local ref not supported") - return self.component.ref.local_name - - if self.type.is_datetime(): - return "datetime" - if self.is_builtin_type: - return self.type.python_type.__name__ - - if self.type.is_global(): - return self.type.local_name - elif 
self.type.is_complex() or self.type.is_list(): - return self.component.local_name - - if self.type.is_restriction(): - # enumeration - enum = self.type.get_facet(qnames.XSD_ENUMERATION) - if enum: - return self.component.local_name - if self.type.base_type.local_name == "string": - return "str" - return "" - - @property - def full_type_string(self) -> str: - """Full type, like Optional[List[str]].""" - if self.key in OVERRIDES and self.type_string: - return f": {self.type_string}" - type_string = self.type_string - if not type_string: - return "" - if not self.max_occurs: - LISTS[self.parent_name].add(self.component.local_name) - - type_string = f"List[{type_string}]" - elif self.is_optional: - type_string = f"Optional[{type_string}]" - return f": {type_string}" if type_string else "" - - @property - def default_val_str(self) -> str: - if self.key in OVERRIDES: - default = OVERRIDES[self.key].default - return f" = {default}" if default else "" - elif not self.is_optional: - return "" - - if not self.max_occurs: - default_val = "Field(default_factory=list)" - else: - default_val = self.component.default - if default_val is not None: - if is_enum_type(self.type): - default_val = f"{self.type_string}('{default_val}')" - elif hasattr(builtins, self.type_string): - default_val = repr(getattr(builtins, self.type_string)(default_val)) - if self.type_string == "Color": - default_val = "Color('white')" - elif self.is_decimal: - default_val = f"cast({self.type_string}, {default_val})" - else: - default_val = "None" - return f" = {default_val}" - - @property - def max_occurs(self) -> int | None: - default = None if self.type.is_list() else 1 - return getattr(self.component, "max_occurs", default) - - @property - def is_optional(self) -> bool: - if self.is_ref_id: - return False - if getattr(self.component.parent, "model", "") == "choice": - return True - if hasattr(self.component, "min_occurs"): - return self.component.min_occurs == 0 - return self.component.is_optional() - - 
def __repr__(self) -> str: - type_ = "element" if isinstance(self.component, XsdElement) else "attribute" - return f"" - - def format(self) -> str: - return f"{self.identifier}{self.full_type_string}{self.default_val_str}" - - -class MemberSet: - def __init__(self, initial: Iterable[Member] = ()): - # Use a list to maintain insertion order. - self._members: list[Member] = [] - self.update(initial) - - def add(self, member: Member) -> None: - if not isinstance(member, Member): - member = Member(member) - # We don't expect very many elements so this O(n) check is fine. - if member in self._members: - return - self._members.append(member) - - def update(self, members: Iterable[Member]) -> None: - for member in members: - self.add(member) - - def lines(self, indent: int = 1) -> list[str]: - if not self._members: - lines = [" " * indent + "pass"] - else: - lines = [ - " " * indent + m.format() - for m in sorted(self._members, key=sort_prop) - ] - return lines - - def imports(self) -> list[str]: - return list(chain.from_iterable(m.imports() for m in self._members)) - - def locals(self) -> list[str]: - return list(chain.from_iterable(m.locals() for m in self._members)) - - def body(self) -> list[str]: - return [m.body() for m in self._members] - - def has_non_default_args(self) -> bool: - return any(not m.default_val_str for m in self._members) - - def has_nonref_id(self) -> bool: - return any(m.is_nonref_id for m in self._members) - - @property - def non_defaults(self) -> MemberSet: - return MemberSet(m for m in self._members if not m.default_val_str) - - def __iter__(self) -> Iterator[Member]: - return iter(self._members) - - def docstring(self, summary: str = "") -> str: - ds = NumpyDocString(summary) - ds["Parameters"] = [ - m.to_numpydoc_param() for m in sorted(self._members, key=sort_prop) - ] - return str(ds) - - -class GlobalElem: - def __init__(self, elem: XsdElement | XsdType): - if not elem.is_global(): - raise ValueError("Element must be global") - self.elem = 
elem - - @property - def type(self) -> XsdType: - return self.elem if self.is_type else self.elem.type - - @property - def is_complex(self) -> bool: - if hasattr(self.type, "is_complex"): - return self.type.is_complex() - return False - - @property - def is_element(self) -> bool: - return isinstance(self.elem, XsdElement) - - @property - def is_type(self) -> bool: - return isinstance(self.elem, XsdType) - - @property - def is_enum(self) -> bool: - is_enum = bool(self.elem.get_facet(qnames.XSD_ENUMERATION) is not None) - if is_enum: - if not len(self.elem.facets) == 1: - raise NotImplementedError("Unexpected enum with multiple facets") - return is_enum - - def _simple_class(self) -> list[str]: - if self.type.local_name == "Color": - return make_color() - if self.is_enum: - return make_enum(self.elem) - # Hack for xmlschema > 1.4.1 - if self.type.local_name == "base64Binary": - return ["class base64Binary(ConstrainedStr):", " pass"] - if self.type.local_name == "Hex40": - return [ - "class Hex40(ConstrainedStr):", - " min_length = 40", - " max_length = 40", - ] - - lines = [] - if self.type.base_type.is_restriction(): - parent = self.type.base_type.local_name - else: - # it's a restriction of a builtin - pytype = self.elem.base_type.python_type.__name__ - parent = f"Constrained{pytype.title()}" - lines.extend([f"from pydantic.types import {parent}", ""]) - lines.append(f"class {self.elem.local_name}({parent}):") - - members = [] - for key, facet in self.elem.facets.items(): - members.extend([f" {line}" for line in facet_converters[key](facet)]) - lines.extend(members if members else [" pass"]) - if any("re.compile" in m for m in members): - lines = ["import re", "", *lines] - return lines - - def lines(self) -> str: - if not self.is_complex: - lines = self._simple_class() - elif self.elem.abstract: - lines = make_abstract_class(self.elem) - else: - lines = make_dataclass(self.elem) - return "\n".join(lines) - - def format(self) -> str: - return 
black_format(sort_imports(autoflake(self.lines() + "\n"))) - - def write(self, filename: str) -> None: - os.makedirs(os.path.dirname(filename), exist_ok=True) - with open(filename, "w", encoding="utf-8") as f: - f.write(self.format()) - - @property - def fname(self) -> str: - return f"{camel_to_snake(self.elem.local_name)}.py" - - -_this_dir = os.path.dirname(__file__) -_url = os.path.join(_this_dir, "ome_types", "ome-2016-06.xsd") -_target = os.path.join(_this_dir, "ome_types", "model") - -AnnotationID_Override = r"""class AnnotationID(LSID): - regex = re.compile( - r"(urn:lsid:([\w\-\.]+\.[\w\-\.]+)+:Annotation:\S+)|(Annotation:\S+)" - ) - - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v) -> "AnnotationID": - if not cls.regex.match(v): - search = cls.regex.search(v) - if search: - import warnings - - new_v = search.group() - warnings.warn(f"Casting invalid AnnotationID {v!r} to {new_v!r}") - v = new_v - return v -""" - -_SIMPLE_OVERRIDES = {"AnnotationID": AnnotationID_Override} - - -def convert_schema(url: str = _url, target_dir: str = _target) -> None: - print("Inspecting XML schema ...") - if isinstance(url, Path): - url = str(url) - schema = XMLSchema(url) - print("Building model ...") - shutil.rmtree(target_dir, ignore_errors=True) - init_imports = [] - simples: list[GlobalElem] = [] - for elem in sorted(schema.types.values(), key=sort_types): - if elem.local_name in OVERRIDES: - continue - converter = GlobalElem(elem) - if not elem.is_complex(): - simples.append(converter) - continue - targetfile = os.path.join(target_dir, converter.fname) - init_imports.append((converter.fname, elem.local_name)) - converter.write(filename=targetfile) - - for elem in schema.elements.values(): - if elem.local_name in OVERRIDES: - continue - converter = GlobalElem(elem) - targetfile = os.path.join(target_dir, converter.fname) - init_imports.append((converter.fname, elem.local_name)) - 
converter.write(filename=targetfile) - - text = "\n".join( - [_SIMPLE_OVERRIDES.get(s.elem.local_name) or s.format() for s in simples] - ) - text = black_format(sort_imports(text)) - with open(os.path.join(target_dir, "simple_types.py"), "w", encoding="utf-8") as f: - f.write(text) - - text = "" - for _fname, classname in init_imports: - text += local_import(classname) + "\n" - text = sort_imports(text) - text += f"\n\n__all__ = [{', '.join(sorted(repr(i[1]) for i in init_imports))}]" - # FIXME These could probably live somewhere else less visible to end-users. - if len(plural_registry) != len(set(plural_registry.values())): - raise Exception( - "singular-to-plural mapping is not invertible (duplicate plurals)" - ) - text += "\n\n_singular_to_plural = " + repr( - {k: plural_registry[k] for k in sorted(plural_registry)} - ) - text += "\n\n_plural_to_singular = " + repr( - {plural_registry[k]: k[1] for k in sorted(plural_registry)} - ) - if len(camel_snake_registry) != len(set(camel_snake_registry.values())): - raise Exception("camel-to-snake mapping is not invertible (duplicate snakes)") - text += "\n\n_camel_to_snake = " + repr( - {k: camel_snake_registry[k] for k in sorted(camel_snake_registry)} - ) - text += "\n\n_snake_to_camel = " + repr( - {camel_snake_registry[k]: k for k in sorted(camel_snake_registry)} - ) - text += "\n\n_lists = " + repr(dict(LISTS)) - text = black_format(text) - with open(os.path.join(target_dir, "__init__.py"), "w", encoding="utf-8") as f: - f.write(text) - - -if __name__ == "__main__": - # for testing - convert_schema() diff --git a/src/ome_autogen/__init__.py b/src/ome_autogen/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ome_autogen/__main__.py b/src/ome_autogen/__main__.py new file mode 100644 index 00000000..b22c2944 --- /dev/null +++ b/src/ome_autogen/__main__.py @@ -0,0 +1,3 @@ +from ome_autogen.main import build_model + +build_model() diff --git a/src/ome_autogen/_config.py 
b/src/ome_autogen/_config.py new file mode 100644 index 00000000..8789bd33 --- /dev/null +++ b/src/ome_autogen/_config.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from xsdata.codegen.writer import CodeWriter +from xsdata.models import config as cfg +from xsdata.utils import text + +from ome_autogen._generator import OmeGenerator +from ome_autogen._util import camel_to_snake + +MIXIN_MODULE = "ome_types._mixins" +MIXINS: list[tuple[str, str, bool]] = [ + (".*", f"{MIXIN_MODULE}._base_type.OMEType", False), # base type on every class + ("OME", f"{MIXIN_MODULE}._ome.OMEMixin", True), + ("Instrument", f"{MIXIN_MODULE}._instrument.InstrumentMixin", False), + ("Reference", f"{MIXIN_MODULE}._reference.ReferenceMixin", True), + ("BinData", f"{MIXIN_MODULE}._bin_data.BinDataMixin", True), + ("Pixels", f"{MIXIN_MODULE}._pixels.PixelsMixin", True), +] + +ALLOW_RESERVED_NAMES = {"type", "Type", "Union"} +OME_FORMAT = "OME" + + +def get_config( + package: str, kw_only: bool = True, compound_fields: bool = False
) -> cfg.GeneratorConfig: + # ALLOW "type" to be used as a field name + text.stop_words.difference_update(ALLOW_RESERVED_NAMES) + + # use our own camel_to_snake + # Ours interprets adjacent capital letters as two words + # NameCase.SNAKE: 'PositionXUnit' -> 'position_xunit' + # camel_to_snake: 'PositionXUnit' -> 'position_x_unit' + cfg.__name_case_func__["snakeCase"] = camel_to_snake + + # critical to be able to use the format="OME" + CodeWriter.register_generator(OME_FORMAT, OmeGenerator) + + mixins = [] + for class_name, import_string, prepend in MIXINS: + mixins.append( + cfg.GeneratorExtension( + type=cfg.ExtensionType.CLASS, + class_name=class_name, + import_string=import_string, + prepend=prepend, + ) + ) + + keep_case = cfg.NameConvention(cfg.NameCase.ORIGINAL, "type") + return cfg.GeneratorConfig( + output=cfg.GeneratorOutput( + package=package, + # format.value lets us use our own generator + # kw_only is important, it makes required fields
actually be required + format=cfg.OutputFormat(value=OME_FORMAT, kw_only=kw_only), + structure_style=cfg.StructureStyle.CLUSTERS, + docstring_style=cfg.DocstringStyle.NUMPY, + compound_fields=cfg.CompoundFields(enabled=compound_fields), + ), + # Add our mixins + extensions=cfg.GeneratorExtensions(mixins), + # Don't convert things like XMLAnnotation to XmlAnnotation + conventions=cfg.GeneratorConventions(class_name=keep_case), + ) diff --git a/src/ome_autogen/_generator.py b/src/ome_autogen/_generator.py new file mode 100644 index 00000000..7175e8ca --- /dev/null +++ b/src/ome_autogen/_generator.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +from xsdata.formats.dataclass.filters import Filters +from xsdata.formats.dataclass.generator import DataclassGenerator + +from ome_autogen import _util +from xsdata_pydantic_basemodel.generator import PydanticBaseFilters + +if TYPE_CHECKING: + from xsdata.codegen.models import Attr, Class + from xsdata.codegen.resolver import DependenciesResolver + from xsdata.models.config import GeneratorConfig + + +# from ome_types._mixins._base_type import AUTO_SEQUENCE +# avoiding import to avoid build-time dependency on the ome-types package +AUTO_SEQUENCE = "__auto_sequence__" + + +class Override(NamedTuple): + element_name: str # name of the attribute in the XSD + class_name: str # name of our override class + module_name: str | None # module where the override class is defined + + +CLASS_OVERRIDES = [ + Override("FillColor", "Color", "ome_types.model._color"), + Override("StrokeColor", "Color", "ome_types.model._color"), + Override("Color", "Color", "ome_types.model._color"), + Override("Union", "ShapeUnion", "ome_types.model._shape_union"), + Override( + "StructuredAnnotations", + "StructuredAnnotationList", + "ome_types.model._structured_annotations", + ), +] +# classes that should never be optional, but always have default_factories +NO_OPTIONAL = {"Union", 
"StructuredAnnotations"} + +# if these names are found as default=..., turn them into default_factory=... +FACTORIZE = set([x.class_name for x in CLASS_OVERRIDES] + ["StructuredAnnotations"]) + +# prebuilt maps for usage in code below +OVERRIDE_ELEM_TO_CLASS = {o.element_name: o.class_name for o in CLASS_OVERRIDES} +IMPORT_PATTERNS = { + o.module_name: { + o.class_name: [f": {o.class_name} =", f": Optional[{o.class_name}] ="] + } + for o in CLASS_OVERRIDES + if o.module_name +} + + +class OmeGenerator(DataclassGenerator): + @classmethod + def init_filters(cls, config: GeneratorConfig) -> Filters: + return OmeFilters(config) + + def render_module( + self, resolver: DependenciesResolver, classes: list[Class] + ) -> str: + mod = super().render_module(resolver, classes) + + # xsdata renames classes like "FillRule" (which appears as a SimpleType + # inside of the Shape ComlexType) as "Shape_FillRule". + # We want to make them available as "FillRule" in the corresponding + # module, (i.e. the "Shape" module in this case). + # That is, we want "Shape = Shape_FillRule" included in the module. + # this is for backwards compatibility. + aliases = [] + for c in classes: + for i in resolver.imports: + if f"{c.name}_" in i.qname: + # import_name is something like 'Shape_FillRule' + import_name = i.qname.rsplit("}", 1)[-1] + # desired alias is just 'FillRule' + alias = import_name.split(f"{c.name}_")[-1] + aliases.append(f"{alias} = {import_name}") + + # we also want inner (nested) classes to be available at the top level + # e.g. map.Map.M -> map.M + for inner in c.inner: + aliases.append(f"{inner.name} = {c.name}.{inner.name}") + + if aliases: + mod += "\n\n" + "\n".join(aliases) + "\n" + + return mod + + +class OmeFilters(PydanticBaseFilters): + def __init__(self, config: GeneratorConfig): + super().__init__(config) + + # TODO: it would be nice to know how to get the schema we're processing from + # the config. 
For now, we just assume it's the OME schema and that's the + # hardcoded default in _util.get_appinfo + self.appinfo = _util.get_appinfo() + + def class_bases(self, obj: Class, class_name: str) -> list[str]: + # we don't need PydanticBaseFilters to add the Base class + # because we add it in the config.extensions + # This could go once PydanticBaseFilters is better about deduping + return Filters.class_bases(self, obj, class_name) + + def _attr_is_optional(self, attr: Attr) -> bool: + if attr.name in NO_OPTIONAL: + return False + return attr.is_nillable or ( + attr.default is None and (attr.is_optional or not self.format.kw_only) + ) + + def _format_type(self, attr: Attr, result: str) -> str: + if self._attr_is_optional(attr): + return f"None | {result}" if self.union_type else f"Optional[{result}]" + return result + + def field_type(self, attr: Attr, parents: list[str]) -> str: + if attr.is_list: + # HACK + # It would be nicer to put this in the self.field_name method...but that + # method only receives the attr name, not the attr object, and so we + # don't know at that point whether it belongs to a list or not. + # This hack works only because this method is called BEFORE self.field_name + # in the class.jinja2 template, so we directly modify the attr object here. 
+ attr.name = self.appinfo.plurals.get(attr.name, f"{attr.name}s") + + if attr.name in OVERRIDE_ELEM_TO_CLASS: + return self._format_type(attr, OVERRIDE_ELEM_TO_CLASS[attr.name]) + + type_name = super().field_type(attr, parents) + # we want to use datetime.datetime instead of XmlDateTime + return type_name.replace("XmlDateTime", "datetime") + + @classmethod + def build_import_patterns(cls) -> dict[str, dict]: + patterns = super().build_import_patterns() + patterns.update(IMPORT_PATTERNS) + patterns["ome_types._mixins._util"] = {"new_uuid": ["default_factory=new_uuid"]} + patterns["datetime"] = {"datetime": ["datetime"]} + return {key: patterns[key] for key in sorted(patterns)} + + def field_default_value(self, attr: Attr, ns_map: dict | None = None) -> str: + if attr.tag == "Attribute" and attr.name == "ID": + return repr(AUTO_SEQUENCE) + for override in CLASS_OVERRIDES: + if attr.name == override.element_name: + if not self._attr_is_optional(attr): + return override.class_name + return super().field_default_value(attr, ns_map) + + def format_arguments(self, kwargs: dict, indent: int = 0) -> str: + # keep default_factory at the front + if kwargs.get("default") in FACTORIZE: + kwargs = {"default_factory": kwargs.pop("default"), **kwargs} + + # uncomment this to use new_uuid as the default_factory for all UUIDs + # but then we have an equality checking problem in the tests + # if kwargs.get("metadata", {}).get("pattern", "").startswith("(urn:uuid:"): + # kwargs.pop("default", None) + # kwargs = {"default_factory": "new_uuid", **kwargs} + return super().format_arguments(kwargs, indent) + + def constant_name(self, name: str, class_name: str) -> str: + if class_name in self.appinfo.enums: + # use the enum names found in appinfo/xsdfu/enum + return self.appinfo.enums[class_name][name].enum + return super().constant_name(name, class_name) diff --git a/src/ome_autogen/_transformer.py b/src/ome_autogen/_transformer.py new file mode 100644 index 00000000..9de8f0cf --- 
/dev/null +++ b/src/ome_autogen/_transformer.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from xsdata.codegen.analyzer import ClassAnalyzer +from xsdata.codegen.container import ClassContainer +from xsdata.codegen.handlers import RenameDuplicateAttributes +from xsdata.codegen.transformer import SchemaTransformer + +if TYPE_CHECKING: + from xsdata.codegen.models import Class + + +# we don't need RenameDuplicateAttributes because we inject +# proper enum names in our _generator.py +UNWANTED_HANDLERS = [(RenameDuplicateAttributes, None)] + + +class OMETransformer(SchemaTransformer): + # overriding to remove the certain handlers + def analyze_classes(self, classes: list[Class]) -> list[Class]: + """Analyze the given class list and simplify attributes and extensions.""" + # xsdata makes this particular class hard to extend/modify + container = ClassContainer(config=self.config) + + for handlers in container.processors.values(): + for idx, h in enumerate(list(handlers)): + for unwanted, wanted in UNWANTED_HANDLERS: + if isinstance(h, unwanted): + handlers.remove(h) + if wanted is not None: + handlers.insert(idx, wanted(container)) + + container.extend(classes) + + return ClassAnalyzer.process(container) diff --git a/src/ome_autogen/_util.py b/src/ome_autogen/_util.py new file mode 100644 index 00000000..fbb99b22 --- /dev/null +++ b/src/ome_autogen/_util.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import os +import re +from collections import defaultdict +from contextlib import contextmanager +from functools import lru_cache +from pathlib import Path +from typing import Any, Iterator, NamedTuple, cast +from xml.etree import ElementTree as ET + +SRC_PATH = Path(__file__).parent.parent +SCHEMA_FILE = (SRC_PATH / "ome_types" / "ome-2016-06.xsd").absolute() + + +@contextmanager +def cd(new_path: str | Path) -> Iterator[None]: + """Temporarily change the current working directory.
+ + Used as a workaround for xsdata not supporting output path. + """ + prev = Path.cwd() + os.chdir(Path(new_path).expanduser().absolute()) + try: + yield + finally: + os.chdir(prev) + + +class EnumInfo(NamedTuple): + name: str + plural: str + unitsystem: str + enum: str + + +class AppInfo(NamedTuple): + plurals: dict[str, str] + enums: dict[str, dict[str, EnumInfo]] + abstract: list[str] + + +@lru_cache(maxsize=None) +def get_appinfo(schema: Path | str = SCHEMA_FILE) -> AppInfo: + """Gather all the stuff from the schema. + + xsdata doesn't try to do anything with it. But we want to use it to + provide better enum and plural names + """ + tree = ET.parse(schema) # noqa: S314 + plurals: dict[str, str] = {} + enums: defaultdict[str, dict[str, EnumInfo]] = defaultdict(dict) + in_name = "" + in_value = "" + abstract = [] + for node in tree.iter(): + if node.tag == "plural": # + plurals[in_name] = cast(str, node.text) + elif node.tag == "enum": # + enums[in_name][in_value] = EnumInfo( + node.get("name", ""), + node.get("plural", ""), + node.get("unitsystem", ""), + node.get("enum", ""), + ) + elif node.tag == "abstract": # + abstract.append(in_name) + else: + in_name = node.get("name", in_name) + in_value = node.get("value", in_value) + + return AppInfo(plurals, dict(enums), abstract) + + +CAMEL_SNAKE_OVERRIDES = {"ROIs": "rois"} +camel_snake_registry: dict[str, str] = {} + + +def camel_to_snake(name: str, **kwargs: Any) -> str: + """Variant of camel_to_snake that preserves adjacent uppercase letters. 
+ + https://stackoverflow.com/a/1176023 + """ + name = name.lstrip("@") # remove leading @ from "@any_element" + result = CAMEL_SNAKE_OVERRIDES.get(name) + if not result: + result = re.sub("([A-Z]+)([A-Z][a-z]+)", r"\1_\2", name) + result = re.sub("([a-z0-9])([A-Z])", r"\1_\2", result) + result = result.lower().replace(" ", "_") + camel_snake_registry[name] = result + return result diff --git a/src/ome_autogen/main.py b/src/ome_autogen/main.py new file mode 100644 index 00000000..8f99832a --- /dev/null +++ b/src/ome_autogen/main.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path +from shutil import rmtree + +from ome_autogen import _util +from ome_autogen._config import get_config +from ome_autogen._transformer import OMETransformer + +OUTPUT_PACKAGE = "ome_types._autogenerated.ome_2016_06" +DO_MYPY = os.environ.get("OME_AUTOGEN_MYPY", "0") == "1" or "--mypy" in sys.argv +SRC_PATH = Path(__file__).parent.parent +SCHEMA_FILE = (SRC_PATH / "ome_types" / "ome-2016-06.xsd").absolute() +RUFF_IGNORE: list[str] = [ + "D101", # Missing docstring in public class + "D106", # Missing docstring in public nested class + "D205", # 1 blank line required between summary line and description + "D404", # First word of the docstring should not be This + "E501", # Line too long + "S105", # Possible hardcoded password +] + + +def build_model( + output_dir: Path | str = SRC_PATH, + schema_file: Path | str = SCHEMA_FILE, + target_package: str = OUTPUT_PACKAGE, + ruff_ignore: list[str] = RUFF_IGNORE, + do_formatting: bool = True, + do_mypy: bool = DO_MYPY, +) -> None: + """Convert the OME schema to a python model.""" + config = get_config(target_package) + transformer = OMETransformer(print=False, config=config) + + _print_gray(f"Processing {getattr(schema_file ,'name', schema_file)}...") + transformer.process_sources([Path(schema_file).resolve().as_uri()]) + + package_dir = str(Path(output_dir) / 
OUTPUT_PACKAGE.replace(".", "/")) + rmtree(package_dir, ignore_errors=True) + with _util.cd(output_dir): # xsdata doesn't support output path + _print_gray("Writing Files...") + transformer.process_classes() + + if do_formatting: + _fix_formatting(package_dir, ruff_ignore) + + if do_mypy: + _check_mypy(package_dir) + + _print_green(f"OME python model created at {OUTPUT_PACKAGE}") + + +def _fix_formatting(package_dir: str, ruff_ignore: list[str] = RUFF_IGNORE) -> None: + _print_gray("Running black and ruff ...") + + black = ["black", package_dir, "-q", "--line-length=88"] + subprocess.check_call(black) # noqa S + + ruff = ["ruff", "-q", "--fix", package_dir] + ruff.extend(f"--ignore={ignore}" for ignore in ruff_ignore) + subprocess.check_call(ruff) # noqa S + + +def _check_mypy(package_dir: str) -> None: + _print_gray("Running mypy ...") + + mypy = ["mypy", package_dir, "--strict"] + try: + subprocess.check_output(mypy, stderr=subprocess.STDOUT) # noqa S + except subprocess.CalledProcessError as e: + raise RuntimeError(f"mypy errors:\n\n{e.output.decode()}") from e + + +def _print_gray(text: str) -> None: + if os.name != "nt": + # UnicodeEncodeError: 'charmap' codec can't encode character '\u2713' + text = f"\033[90m\033[1m{text}\033[0m" + print(text) + + +def _print_green(text: str) -> None: + if os.name != "nt": + # UnicodeEncodeError: 'charmap' codec can't encode character '\u2713' + text = f"\033[92m\033[1m{text}\033[0m" + print(text) diff --git a/src/ome_types/__init__.py b/src/ome_types/__init__.py index dc235ce3..4768500a 100644 --- a/src/ome_types/__init__.py +++ b/src/ome_types/__init__.py @@ -1,30 +1,21 @@ -from typing import Any - -from ._units import ureg - try: - from ._version import version as __version__ + from importlib.metadata import PackageNotFoundError, version except ImportError: - __version__ = "unknown" + from importlib_metadata import PackageNotFoundError, version # type: ignore try: - from . 
import model - from .model import OME -except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Could not import 'ome_types.model.OME'.\nIf you are in a dev environment, " - "you may need to run 'python -m src.ome_autogen'" + str(e) - ) from None + __version__ = version("ome-types") +except PackageNotFoundError: + __version__ = "unknown" -from ._convenience import ( # isort:skip - from_tiff, - from_xml, - to_dict, - to_xml, - validate_xml, -) +from ome_types import model +from ome_types._conversion import from_tiff, from_xml, to_dict, to_xml +from ome_types.model import OME +from ome_types.units import ureg +from ome_types.validation import validate_xml __all__ = [ + "__version__", "from_tiff", "from_xml", "model", @@ -33,19 +24,4 @@ "to_xml", "ureg", "validate_xml", - "__version__", ] - - -def __getattr__(name: str) -> Any: - if name == "validate": - import warnings - - warnings.warn( - "'ome_types.validate' has been renamed to 'ome_types.validate_xml. " - "This will raise an exception in the future.", - FutureWarning, - stacklevel=2, - ) - return validate_xml - raise AttributeError("module {__name__!r} has no attribute {name!r}") diff --git a/src/ome_types/_base_type.py b/src/ome_types/_base_type.py deleted file mode 100644 index a40e439b..00000000 --- a/src/ome_types/_base_type.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Sequence, Set - -from pydantic import BaseModel, validator - -if TYPE_CHECKING: - import pint - - -class Sentinel: - """Create singleton sentinel objects with a readable repr.""" - - def __init__(self, name: str) -> None: - self.name = name - - def __repr__(self) -> str: - return f"{__name__}.{self.name}.{id(self)}" - - -def quantity_property(field: str) -> property: - """Create property that returns a ``pint.Quantity`` combining value and unit.""" - - def quantity(self: Any) -> Optional["pint.Quantity"]: - from ome_types._units import ureg - - value = getattr(self, field) - if 
value is None: - return None - unit = getattr(self, f"{field}_unit").value.replace(" ", "_") - return ureg.Quantity(value, unit) - - return property(quantity) - - -class OMEType(BaseModel): - """The base class that all OME Types inherit from. - - This provides some global conveniences around auto-setting ids. (i.e., making them - optional in the class constructor, but never ``None`` after initialization.). - It provides a nice __repr__ that hides things that haven't been changed from - defaults. It adds ``*_quantity`` property for fields that have both a value and a - unit, where ``*_quantity`` is a pint ``Quantity``. It also provides pickling - support. - """ - - # Default value to support automatic numbering for id field values. - _AUTO_SEQUENCE = Sentinel("AUTO_SEQUENCE") - # allow use with weakref - __slots__: ClassVar[Set[str]] = {"__weakref__"} # type: ignore - - def __init__(__pydantic_self__, **data: Any) -> None: - if "id" in __pydantic_self__.__fields__: - data.setdefault("id", OMEType._AUTO_SEQUENCE) - super().__init__(**data) - - def __init_subclass__(cls) -> None: - """Add some properties to subclasses with units. - - It adds ``*_quantity`` property for fields that have both a value and a - unit, where ``*_quantity`` is a pint ``Quantity`` - """ - _clsdir = set(cls.__fields__) - for field in _clsdir: - if f"{field}_unit" in _clsdir: - setattr(cls, f"{field}_quantity", quantity_property(field)) - - # pydantic BaseModel configuration. - # see: https://pydantic-docs.helpmanual.io/usage/model_config/ - class Config: - # whether to allow arbitrary user types for fields (they are validated - # simply by checking if the value is an instance of the type). 
If - # False, RuntimeError will be raised on model declaration - arbitrary_types_allowed = False - # whether to perform validation on assignment to attributes - validate_assignment = True - # whether to treat any underscore non-class var attrs as private - # https://pydantic-docs.helpmanual.io/usage/models/#private-model-attributes - underscore_attrs_are_private = True - # whether to populate models with the value property of enums, rather - # than the raw enum. This may be useful if you want to serialise - # model.dict() later. False by default - # see conversation in https://github.com/tlambert03/ome-types/pull/74 - use_enum_values = False - # whether to validate field defaults (default: False) - validate_all = True - - def __repr__(self: Any) -> str: - from datetime import datetime - from enum import Enum - from textwrap import indent - - name = self.__class__.__qualname__ - lines = [] - for f in sorted( - self.__fields__.values(), key=lambda f: f.name not in ("name", "id") - ): - if f.name.endswith("_"): - continue - # https://github.com/python/mypy/issues/6910 - if f.default_factory: - default = f.default_factory() - else: - default = f.default - - current = getattr(self, f.name) - if current != default: - if isinstance(current, Sequence) and not isinstance(current, str): - rep = f"[<{len(current)} {f.name.title()}>]" - elif isinstance(current, Enum): - rep = repr(current.value) - elif isinstance(current, datetime): - rep = f"datetime.fromisoformat({current.isoformat()!r})" - else: - rep = repr(current) - lines.append(f"{f.name}={rep},") - if len(lines) == 1: - body = lines[-1].rstrip(",") - elif lines: - body = "\n" + indent("\n".join(lines), " ") + "\n" - else: - body = "" - out = f"{name}({body})" - return out - - @validator("id", pre=True, always=True, check_fields=False) - def validate_id(cls, value: Any) -> str: - """Pydantic validator for ID fields in OME models. 
- - If no value is provided, this validator provides and integer ID, and stores the - maximum previously-seen value on the class. - """ - from typing import ClassVar - - # get the required LSID field from the annotation - id_field = cls.__fields__.get("id") - if not id_field: - return value - - # Store the highest seen value on the class._max_id attribute. - if not hasattr(cls, "_max_id"): - cls._max_id = 0 # type: ignore [misc] - cls.__annotations__["_max_id"] = ClassVar[int] - if value is OMEType._AUTO_SEQUENCE: - value = cls._max_id + 1 - if isinstance(value, int): - v_id = value - id_string = id_field.type_.__name__[:-2] - value = f"{id_string}:{value}" - else: - value = str(value) - v_id = value.rsplit(":", 1)[-1] - try: - v_id = int(v_id) - cls._max_id = max(cls._max_id, v_id) # type: ignore [misc] - except ValueError: - pass - - return id_field.type_(value) - - def __getstate__(self: Any) -> Dict[str, Any]: - """Support pickle of our weakref references.""" - state = super().__getstate__() - state["__private_attribute_values__"].pop("_ref", None) - return state - - @classmethod - def snake_name(cls) -> str: - from .model import _camel_to_snake - - return _camel_to_snake[cls.__name__] diff --git a/src/ome_types/_constants.py b/src/ome_types/_constants.py deleted file mode 100644 index 00033e56..00000000 --- a/src/ome_types/_constants.py +++ /dev/null @@ -1,9 +0,0 @@ -from pathlib import Path - -URI_OME = "http://www.openmicroscopy.org/Schemas/OME/2016-06" -NS_OME = "{" + URI_OME + "}" -NS_XSI = "{http://www.w3.org/2001/XMLSchema-instance}" -SCHEMA_LOC_OME = f"{URI_OME} {URI_OME}/ome.xsd" - - -OME_2016_06_XSD = str(Path(__file__).parent / "ome-2016-06.xsd") diff --git a/src/ome_types/_convenience.py b/src/ome_types/_convenience.py deleted file mode 100644 index 7a981a79..00000000 --- a/src/ome_types/_convenience.py +++ /dev/null @@ -1,228 +0,0 @@ -import os -from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast -from 
warnings import warn - -from typing_extensions import Protocol - -from .model import OME - -if TYPE_CHECKING: - from ._xmlschema import XMLSourceType - - -class Parser(Protocol): - # Used for type checks on xml parsers - def __call__( - self, path_or_str: Union[Path, str, bytes], validate: Optional[bool] = False - ) -> Dict[str, Any]: - ... - - -def to_dict( - xml: Union[Path, str, bytes], - *, - parser: Union[Parser, str, None] = None, - validate: Optional[bool] = None, -) -> Dict[str, Any]: - """Convert OME XML to dict. - - Parameters - ---------- - xml : Union[Path, str, bytes] - XML string or path to XML file. - parser : Union[Parser, str] - Either a parser callable with signature: - `(path_or_str: Union[Path, str, bytes], validate: Optional[bool] = False) -> - Dict`, or a string. If a string, must be either 'lxml' or 'xmlschema'. by - default "lxml" - validate : Optional[bool], optional - Whether to validate XML as valid OME XML, by default (`None`), the choices is - left to the parser (which is `False` for the lxml parser) - - Returns - ------- - Dict[str, Any] - OME model dict. - - Raises - ------ - KeyError - If `parser` is a string, and not one of `'lxml'` or `'xmlschema'` - """ - if parser is None: - warn( - "The default XML parser will be changing from 'xmlschema' to 'lxml' in " - "version 0.4.0. 
To silence this warning, please provide the `parser` " - "argument, specifying either 'lxml' (to opt into the new behavior), or" - "'xmlschema' (to retain the old behavior).", - FutureWarning, - stacklevel=2, - ) - parser = "xmlschema" - - if isinstance(parser, str): - if parser == "lxml": - from ._lxml import xml2dict - - parser = cast(Parser, xml2dict) - elif parser == "xmlschema": - from ._xmlschema import xmlschema2dict - - parser = cast(Parser, xmlschema2dict) - else: - raise KeyError("parser string must be one of {'lxml', 'xmlschema'}") - - if validate is True: - validate_xml(xml) - - d = parser(xml) - for key in list(d.keys()): - if key.startswith(("xml", "xsi")): - d.pop(key) - return d - - -def from_xml( - xml: Union[Path, str, bytes], - *, - parser: Union[Parser, str, None] = None, - validate: Optional[bool] = None, -) -> OME: - """Generate OME metadata object from XML string or path. - - Parameters - ---------- - xml : Union[Path, str, bytes] - XML string or path to XML file. - parser : Union[Parser, str] - Either a parser callable with signature: `(path_or_str: Union[Path, str, bytes], - validate: Optional[bool] = False) -> Dict`, or a string. If a string, must be - either 'lxml' or 'xmlschema'. by default "lxml" - validate : Optional[bool], optional - Whether to validate XML as valid OME XML, by default (`None`), the choices is - left to the parser (which is `False` for the lxml parser) - - - Returns - ------- - ome: ome_types.model.ome.OME - ome_types.OME metadata object - """ - d = to_dict(os.fspath(xml), parser=parser, validate=validate) - return OME(**d) - - -def from_tiff( - path: Union[Path, str], - *, - parser: Union[Parser, str, None] = None, - validate: Optional[bool] = True, -) -> OME: - """Generate OME metadata object from OME-TIFF path. - - This will use the first ImageDescription tag found in the TIFF header. - - Parameters - ---------- - path : Union[Path, str] - Path to OME TIFF. 
- parser : Union[Parser, str] - Either a parser callable with signature: `(path_or_str: Union[Path, str, bytes], - validate: Optional[bool] = False) -> Dict`, or a string. If a string, must be - either 'lxml' or 'xmlschema'. by default "lxml" - validate : Optional[bool], optional - Whether to validate XML as valid OME XML, by default (`None`), the choices is - left to the parser (which is `False` for the lxml parser) - - - Returns - ------- - ome: ome_types.model.ome.OME - ome_types.OME metadata object - - Raises - ------ - ValueError - If the TIFF file has no OME metadata. - """ - return from_xml(_tiff2xml(path), parser=parser, validate=validate) - - -def _tiff2xml(path: Union[Path, str]) -> bytes: - """Extract OME XML from OME-TIFF path. - - This will use the first ImageDescription tag found in the TIFF header. - - Parameters - ---------- - path : Union[Path, str] - Path to OME TIFF. - - Returns - ------- - xml : str - OME XML - - Raises - ------ - ValueError - If the TIFF file has no OME metadata. 
- """ - from struct import unpack - - with Path(path).open(mode="rb") as fh: - try: - offsetsize, offsetformat, tagnosize, tagnoformat, tagsize, codeformat = { - b"II*\0": (4, "I", 2, ">H", 12, ">H"), - b"II+\0": (8, "Q", 8, ">Q", 20, ">H"), - }[fh.read(4)] - except KeyError as e: - raise ValueError(f"{path!r} does not have a recognized TIFF header") from e - - fh.read(4 if offsetsize == 8 else 0) - fh.seek(unpack(offsetformat, fh.read(offsetsize))[0]) - for _ in range(unpack(tagnoformat, fh.read(tagnosize))[0]): - tagstruct = fh.read(tagsize) - if unpack(codeformat, tagstruct[:2])[0] == 270: - size = unpack(offsetformat, tagstruct[4 : 4 + offsetsize])[0] - if size <= offsetsize: - desc = tagstruct[4 + offsetsize : 4 + offsetsize + size] - break - fh.seek(unpack(offsetformat, tagstruct[-offsetsize:])[0]) - desc = fh.read(size) - break - else: - raise ValueError(f"No OME metadata found in file: {path}") - if desc[-1] == 0: - desc = desc[:-1] - return desc - - -def to_xml(ome: OME, **kwargs: Any) -> str: - """ - Dump an OME object to string. - - Parameters - ---------- - ome: OME - OME object to dump. - **kwargs - Extra kwargs to pass to ElementTree.tostring. - - Returns - ------- - ome_string: str - The XML string of the OME object. 
- """ - from ._xmlschema import to_xml - - return to_xml(ome, **kwargs) - - -def validate_xml(xml: "XMLSourceType", schema: Any = None) -> None: - from ._xmlschema import validate - - validate(xml, schema=schema) diff --git a/src/ome_types/_conversion.py b/src/ome_types/_conversion.py new file mode 100644 index 00000000..e3b63e94 --- /dev/null +++ b/src/ome_types/_conversion.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import importlib +import operator +import os +import warnings +from dataclasses import is_dataclass +from pathlib import Path +from struct import Struct +from typing import TYPE_CHECKING, Any, cast + +from ome_types.validation import validate_xml + +try: + from lxml import etree as ET +except ImportError: # pragma: no cover + from xml.etree import ElementTree as ET # type: ignore[no-redef] + +from xsdata.formats.dataclass.parsers.config import ParserConfig + +from xsdata_pydantic_basemodel.bindings import ( + SerializerConfig, + XmlParser, + XmlSerializer, +) + +if TYPE_CHECKING: + import io + from typing import TypedDict + + from xsdata.formats.dataclass.parsers.mixins import XmlHandler + + from ome_types._mixins._base_type import OMEType + from ome_types.model import OME + from xsdata_pydantic_basemodel.bindings import XmlContext + + class ParserKwargs(TypedDict, total=False): + config: ParserConfig + context: XmlContext + handler: type[XmlHandler] + + +__all__ = ["from_xml", "to_xml", "to_dict", "from_tiff", "tiff2xml"] + +OME_2016_06_URI = "http://www.openmicroscopy.org/Schemas/OME/2016-06" +MODULES = { + OME_2016_06_URI: "ome_types._autogenerated.ome_2016_06", +} + + +def _get_ome_type(xml: str | bytes) -> type[OMEType]: + """Resolve a python model class for the root element of an OME XML document.""" + if isinstance(xml, str) and not xml.startswith("<"): + root = ET.parse(xml).getroot() # noqa: S314 + else: + if not isinstance(xml, bytes): + xml = xml.encode("utf-8") + root = ET.fromstring(xml) # noqa: S314 + + *_ns, localname = 
root.tag[1:].split("}", 1) + ns = next(iter(_ns), None) + + if not ns or ns not in MODULES: + raise ValueError(f"Unsupported OME schema tag {root.tag!r} in namespace {ns!r}") + + mod = importlib.import_module(MODULES[ns]) + try: + return getattr(mod, localname) + except AttributeError as e: + raise ValueError( + f"Could not find a class for {localname!r} in {mod.__name__}" + ) from e + + +def to_dict(source: OME | Path | str | bytes) -> dict[str, Any]: + if is_dataclass(source): + raise NotImplementedError("dataclass -> dict is not supported yet") + + return from_xml( # type: ignore[return-value] + cast("Path | str | bytes", source), + # the class_factory is what prevents class instantiation, + # simply returning the params instead + parser_kwargs={"config": ParserConfig(class_factory=lambda a, b: b)}, + ) + + +def from_xml( + xml: Path | str | bytes, + *, + validate: bool | None = None, + parser: Any = None, + parser_kwargs: ParserKwargs | None = None, +) -> OME: + if parser is not None: + warnings.warn( + "As of version 0.4.0, the parser argument is ignored. " + "lxml will be used if available in the environment, but you can " + "drop this keyword argument.", + DeprecationWarning, + stacklevel=2, + ) + + if validate: + validate_xml(xml) + + if isinstance(xml, Path): + xml = str(xml) + + # this cast is a lie... but it's by far the most common type that will + # come out of this function, and will be more useful to most users. + # For those who pass in an xml document that isn't just a root tag, + # they can cast the result to the correct type themselves. 
+ OME_type = cast("type[OME]", _get_ome_type(xml)) + + parser_ = XmlParser(**(parser_kwargs or {})) + if isinstance(xml, bytes): + return parser_.from_bytes(xml, OME_type) + if os.path.isfile(xml): + return parser_.parse(xml, OME_type) + return parser_.from_string(xml, OME_type) + + +def to_xml( + ome: OME, + *, + # exclude_defaults takes precendence over exclude_unset + # if a value equals the default, it will be excluded + exclude_defaults: bool = False, + # exclude_unset will exclude any value that is not explicitly set + # but will INCLUDE values that are set to their default + exclude_unset: bool = True, + indent: int = 2, + include_namespace: bool | None = None, + include_schema_location: bool = True, + canonicalize: bool = False, + validate: bool = False, +) -> str: + config = SerializerConfig( + pretty_print=(indent > 0) and not canonicalize, # canonicalize does it for us + pretty_print_indent=" " * indent, + xml_declaration=False, + ignore_default_attributes=exclude_defaults, + ignore_unset_attributes=exclude_unset, + attribute_sort_key=operator.attrgetter("name") if canonicalize else None, + ) + if include_schema_location: + config.schema_location = f"{OME_2016_06_URI} {OME_2016_06_URI}/ome.xsd" + + serializer = XmlSerializer(config=config) + if include_namespace is None: + include_namespace = canonicalize + + ns_map = {"ome" if include_namespace else None: OME_2016_06_URI} + xml = serializer.render(ome, ns_map=ns_map) + + if canonicalize: + xml = _canonicalize(xml, indent=" " * indent) + if validate: + validate_xml(xml) + return xml + + +def _canonicalize(xml: str, indent: str) -> str: + from xml.dom import minidom + + xml_out = ET.canonicalize(xml, strip_text=True) + return minidom.parseString(xml_out).toprettyxml(indent=indent) # noqa: S318 + + +def from_tiff( + path: Path | str, + *, + validate: bool | None = None, + parser_kwargs: ParserKwargs | None = None, +) -> OME: + xml = tiff2xml(path) + return from_xml(xml, validate=validate, 
parser_kwargs=parser_kwargs) + + +TIFF_TYPES: dict[bytes, tuple[Struct, Struct, int, Struct]] = { + b"II*\0": (Struct("I"), Struct(">H"), 12, Struct(">H")), + b"II+\0": (Struct("Q"), Struct(">Q"), 20, Struct(">H")), +} + + +def _unpack(fh: io.BufferedReader, strct: Struct) -> int: + return strct.unpack(fh.read(strct.size))[0] + + +def tiff2xml(path: Path | str) -> bytes: + """Extract the OME-XML from a TIFF file.""" + with Path(path).open(mode="rb") as fh: + head = fh.read(4) + if head not in TIFF_TYPES: + raise ValueError(f"{path!r} does not have a recognized TIFF header") + + offset_fmt, tagno_fmt, tagsize, codeformat = TIFF_TYPES[head] + offset_size = offset_fmt.size + offset_size_4 = offset_size + 4 + + if offset_size == 8: + fh.seek(4, 1) + fh.seek(_unpack(fh, offset_fmt)) + for _ in range(_unpack(fh, tagno_fmt)): + tagstruct = fh.read(tagsize) + if codeformat.unpack(tagstruct[:2])[0] == 270: + size = offset_fmt.unpack(tagstruct[4:offset_size_4])[0] + if size <= offset_size: + desc = tagstruct[offset_size_4 : offset_size_4 + size] + break + fh.seek(offset_fmt.unpack(tagstruct[-offset_size:])[0]) + desc = fh.read(size) + break + else: + raise ValueError(f"No OME metadata found in file: {path}") + if desc[-1] == 0: + desc = desc[:-1] + return desc diff --git a/src/ome_types/_lxml.py b/src/ome_types/_lxml.py deleted file mode 100644 index 645409c9..00000000 --- a/src/ome_types/_lxml.py +++ /dev/null @@ -1,229 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Container, MutableSequence, Union, cast - -from typing_extensions import get_args - -from . 
import model -from ._constants import OME_2016_06_XSD, URI_OME -from .model.shape_group import ShapeGroupType -from .util import _ensure_xml_bytes, _get_plural, camel_to_snake, cast_number, norm_key - -NEED_INT = [s.__name__ for s in get_args(ShapeGroupType)] -NEED_INT.extend(["Channel", "Well"]) - - -def _is_xml_comment(element: Element) -> bool: - return False - - -if TYPE_CHECKING: - import xml.etree.ElementTree - - import lxml.etree - - Element = Union[xml.etree.ElementTree.Element, lxml.etree._Element] - ElementTree = Union[xml.etree.ElementTree.ElementTree, lxml.etree._ElementTree] - Value = Union[float, str, int, bytearray, bool, dict[str, Any], list[Any]] - Parser = Callable[[bytes], Element] - - XML: Parser - tostring: Callable[[Element], bytes] - parse: Callable[[str], ElementTree] - -else: - try: - # faster if it's available - from lxml.etree import XML, _Comment, parse, tostring - - def _is_xml_comment(element: Element) -> bool: - return isinstance(element, _Comment) - - except ImportError: - from xml.etree.ElementTree import XML, parse, tostring - - -def elem2dict(node: Element, exclude_null: bool = True) -> dict[str, Any]: - """Convert an xml.etree or lxml.etree Element into a dict. - - Parameters - ---------- - node : Element - The Element to convert. Should be an `xml.etree.ElementTree.Element` or a - `lxml.etree._Element` - exclude_null : bool, optional - If True, exclude keys with null values from the output. - - - Returns - ------- - dict[str, Any] - The converted Element. 
- """ - result: dict[str, Any] = {} - - # Re-used valued - norm_node = norm_key(node.tag) - # set of keys that are lists - norm_list: Container = model._lists.get(norm_node, {}) - - # Process attributes - for key, val in node.attrib.items(): - is_list = key in norm_list - key = camel_to_snake(norm_key(key)) - if norm_node in NEED_INT: - val = cast_number(val) - if is_list: - key = _get_plural(key, node.tag) - if key not in result: - result[key] = [] - cast("list", result[key]).extend(val.split()) - else: - result[key] = val - - # Process children - for element in node: - element = cast("Element", element) - if _is_xml_comment(element): - continue - key = norm_key(element.tag) - - # Process element as tree element if inner XML contains non-whitespace content - if element.text and element.text.strip(): - value: Any = element.text - if element.attrib.items(): - value = {"value": value} - for k, val in element.attrib.items(): - value[camel_to_snake(norm_key(k))] = val - - elif key == "MetadataOnly": - value = True - - else: - value = elem2dict(element, exclude_null=exclude_null) - if key == "XMLAnnotation": - value["value"] = tostring(element[0]) - - is_list = key in norm_list - key = camel_to_snake(key) - if is_list: - if key == "bin_data" and value["length"] == "0" and "value" not in value: - value["value"] = "" - key = _get_plural(key, node.tag) - if key not in result: - result[key] = [] - cast("list", result[key]).append(value) - - elif key == "structured_annotations": - annotations = [] - for _type in ( - "boolean_annotation", - "comment_annotation", - "double_annotation", - "file_annotation", - "list_annotation", - "long_annotation", - "map_annotation", - "tag_annotation", - "term_annotation", - "timestamp_annotation", - "xml_annotation", - ): - if _type in value: - values = value.pop(_type) - if not isinstance(values, list): - values = [values] - - for v in values: - v["_type"] = _type - - # Normalize empty element to zero-length string. 
- if "value" not in v or v["value"] is None: - v["value"] = "" - annotations.extend(values) - - if key in result: - raise ValueError("Duplicate structured_annotations") - result[key] = annotations - - elif value or not exclude_null: - try: - rv = result[key] - except KeyError: - result[key] = [value] if key.lower() == "m" else value - else: - if not isinstance(rv, MutableSequence) or not rv: - result[key] = [rv, value] - elif isinstance(rv[0], MutableSequence) or not isinstance( - value, MutableSequence - ): - rv.append(value) - else: - result[key] = [result, value] - - return result - - -def validate_lxml(node: Element) -> Element: - """Ensure that `node` is valid OMX XML. - - Raises - ------ - lxml.etree.XMLSchemaValidateError - If `node` is not valid OME XML. - """ - # TODO: unify with xmlschema validate - try: - import lxml.etree - except ImportError as e: - raise ImportError("validating xml requires lxml") from e - - for key, val in node.attrib.items(): - if "schemaLocation" in key: - ns, uri = val.split() - if ns == URI_OME: - uri = OME_2016_06_XSD - schema_doc = parse(uri) # noqa: S314 - break - if not lxml.etree.XMLSchema(schema_doc).validate(node): - raise lxml.etree.XMLSchemaValidateError( - f"XML did not pass validation error against {uri}" - ) - return node - - -def xml2dict( - path_or_str: Path | str | bytes, validate: bool | None = False -) -> dict[str, Any]: - """Convert XML string or path to dict using lxml. - - xml : Union[Path, str, bytes] - XML string, bytes, or path to XML file. - validate : Optional[bool] - Whether to validate XML as valid OME XML, by default False. 
- - Returns - ------- - Dict[str, Any] - OME dict - """ - root = XML(_ensure_xml_bytes(path_or_str)) - if validate: - validate_lxml(root) - - return elem2dict(root) - - -def __getattr__(name: str) -> Any: - """Import lxml if it is not already imported.""" - if name == "lxml2dict": - import warnings - - warnings.warn( - "lxml2dict is deprecated, use xml2dict instead", - FutureWarning, - stacklevel=2, - ) - return xml2dict - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/ome_types/_mixins/__init__.py b/src/ome_types/_mixins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ome_types/_mixins/_base_type.py b/src/ome_types/_mixins/_base_type.py new file mode 100644 index 00000000..aa774a24 --- /dev/null +++ b/src/ome_types/_mixins/_base_type.py @@ -0,0 +1,144 @@ +import warnings +from datetime import datetime +from enum import Enum +from textwrap import indent +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Set, Tuple, cast + +from pydantic import BaseModel, validator + +from ome_types._mixins._ids import validate_id +from ome_types.units import ureg + +if TYPE_CHECKING: + import pint + + +# Default value to support automatic numbering for id field values. 
+AUTO_SEQUENCE = "__auto_sequence__" + + +_UNIT_FIELD = "{}_unit" +_QUANTITY_FIELD = "{}_quantity" +DEPRECATED_NAMES = { + "annotation_ref": "annotation_refs", + "bin_data": "bin_data_blocks", + "dataset_ref": "dataset_refs", + "emission_filter_ref": "emission_filters", + "excitation_filter_ref": "excitation_filters", + "experimenter_ref": "experimenter_refs", + "folder_ref": "folder_refs", + "image_ref": "image_refs", + "leader": "leaders", + "light_source_settings": "light_source_settings_combinations", + "m": "ms", + "microbeam_manipulation_ref": "microbeam_manipulation_refs", + "plate_ref": "plate_refs", + "roi_ref": "roi_refs", + "well_sample_ref": "well_sample_refs", +} + + +class OMEType(BaseModel): + """The base class that all OME Types inherit from. + + This provides some global conveniences around auto-setting ids. (i.e., making them + optional in the class constructor, but never ``None`` after initialization.). + It provides a nice __repr__ that hides things that haven't been changed from + defaults. It adds ``*_quantity`` property for fields that have both a value and a + unit, where ``*_quantity`` is a pint ``Quantity``. It also provides pickling + support. + """ + + # pydantic BaseModel configuration. 
+ # see: https://pydantic-docs.helpmanual.io/usage/model_config/ + class Config: + arbitrary_types_allowed = False + validate_assignment = True + underscore_attrs_are_private = True + use_enum_values = False + validate_all = True + + # allow use with weakref + __slots__: ClassVar[Set[str]] = {"__weakref__"} # type: ignore + + _v = validator("id", pre=True, always=True, check_fields=False)(validate_id) + + def __init__(self, **data: Any) -> None: + super().__init__(**data) + field_names = set(self.__fields__.keys()) + kwargs = set(data.keys()) + if kwargs - field_names: + warnings.warn(f"Unrecognized fields: {kwargs - field_names}", stacklevel=2) + + def __init_subclass__(cls) -> None: + """Add `*_quantity` property for fields that have both a value and a unit. + + where `*_quantity` is a pint `Quantity`. + """ + for field in cls.__fields__: + if _UNIT_FIELD.format(field) in cls.__fields__: + setattr(cls, _QUANTITY_FIELD.format(field), _quantity_property(field)) + + def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: + """Repr with only set values, and truncated sequences.""" + args = [] + for k, v in self._iter(exclude_defaults=True): + if isinstance(v, Sequence) and not isinstance(v, str): + # if this is a sequence with a long repr, just show the length + # and type + if len(repr(v).split(",")) > 5: + type_name = self.__fields__[k].type_.__name__ + v = _RawRepr(f"[<{len(v)} {type_name}>]") + elif isinstance(v, Enum): + v = v.value + elif isinstance(v, datetime): + v = v.isoformat() + args.append((k, v)) + return sorted(args, key=lambda f: f[0] not in ("name", "id")) + + def __repr__(self) -> str: + lines = [f"{key}={val!r}," for key, val in self.__repr_args__()] + if len(lines) == 1: + body = lines[-1].rstrip(",") + elif lines: + body = "\n" + indent("\n".join(lines), " ") + "\n" + else: + body = "" + return f"{self.__class__.__qualname__}({body})" + + def __getattr__(self, key: str) -> Any: + """Getattr that redirects deprecated names.""" + cls_name = 
self.__class__.__name__ + if key in DEPRECATED_NAMES and hasattr(self, DEPRECATED_NAMES[key]): + new_key = DEPRECATED_NAMES[key] + warnings.warn( + f"Attribute '{cls_name}.{key}' is deprecated, use {new_key!r} instead", + DeprecationWarning, + stacklevel=2, + ) + return getattr(self, new_key) + raise AttributeError(f"{cls_name} object has no attribute {key!r}") + + +def _quantity_property(field_name: str) -> property: + """Create property that returns a ``pint.Quantity`` combining value and unit.""" + + def quantity(self: Any) -> Optional["pint.Quantity"]: + value = getattr(self, field_name) + if value is None: + return None + + unit = cast("Enum", getattr(self, _UNIT_FIELD.format(field_name))) + return ureg.Quantity(value, unit.value.replace(" ", "_")) + + return property(quantity) + + +class _RawRepr: + """Helper class to allow repr to show raw values for fields that are sequences.""" + + def __init__(self, raw: str) -> None: + self.raw = raw + + def __repr__(self) -> str: + return self.raw diff --git a/src/ome_types/_mixins/_bin_data.py b/src/ome_types/_mixins/_bin_data.py new file mode 100644 index 00000000..7ddffd18 --- /dev/null +++ b/src/ome_types/_mixins/_bin_data.py @@ -0,0 +1,22 @@ +import warnings +from typing import Any, Dict + +from pydantic import root_validator + +from ome_types._mixins._base_type import OMEType + + +class BinDataMixin(OMEType): + @root_validator(pre=True) + def _v(cls, values: dict) -> Dict[str, Any]: + # This catches the case of , where the parser may have + # omitted value from the dict, and sets value to b"" + # seems like it could be done in a default_factory, but that would + # require more modification of xsdata I think + if "value" not in values: + if values.get("length") != 0: + warnings.warn( + "BinData length is non-zero but value is missing", stacklevel=2 + ) + values["value"] = b"" + return values diff --git a/src/ome_types/_mixins/_ids.py b/src/ome_types/_mixins/_ids.py new file mode 100644 index 00000000..19af2c73 --- 
/dev/null +++ b/src/ome_types/_mixins/_ids.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import re +import warnings +from contextlib import suppress +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from pydantic import BaseModel + from typing_extensions import Final + +# Default value to support automatic numbering for id field values. +AUTO_SEQUENCE: Final = "__auto_sequence__" +# map of id_name -> max id value +ID_COUNTER: dict[str, int] = {} + +# map of (id_name, id_value) -> converted id +# NOTE: this is cleared in OMEMixin.__init__, so that the set of converted IDs +# is unique to each OME instance +CONVERTED_IDS: dict[tuple[str, str], str] = {} + + +def _get_id_name_and_pattern(cls: type[BaseModel]) -> tuple[str, str]: + # let this raise if it doesn't exist... + # this should only be used on classes that have an id field + id_field = cls.__fields__["id"] + id_pattern = cast(str, id_field.field_info.regex) + id_name = id_pattern.split(":")[-3] + + return id_name, id_pattern + + +def validate_id(cls: type[BaseModel], value: int | str) -> Any: + """Pydantic validator for ID fields in OME models. + + This validator does the following: + 1. if it's valid string ID just use it, and updating the counter if necessary. + 2. if it's an invalid string id, try to extract the integer part from it, and use + that to create a new ID, or use the next value in the sequence if not. + 2. if it's an integer, grab the appropriate ID name from the pattern and prepend it. + 3. if it's the special `AUTO_SEQUENCE` sentinel, use the next value in the sequence. + + COUNTERS stores the maximum previously-seen value on the class. 
+ """ + id_name, id_pattern = _get_id_name_and_pattern(cls) + current_count = ID_COUNTER.setdefault(id_name, -1) + + if value == AUTO_SEQUENCE: + # if it's the special sentinel, use the next value + value = ID_COUNTER[id_name] + 1 + elif isinstance(value, str): + if (id_name, value) in CONVERTED_IDS: + # XXX: possible bug + # if the same invalid value is used across multiple documents + # we'll be replacing it with the same converted id here + return CONVERTED_IDS[(id_name, value)] + + # if the value is a string, extract the number from it if possible + value_id: str = value.rsplit(":", 1)[-1] + + # if the value matches the pattern, just return it + # but update the counter if it's higher than the current value + if re.match(id_pattern, value): + with suppress(ValueError): + # (not all IDs have integers after the colon) + ID_COUNTER[id_name] = max(current_count, int(value_id)) + return value + + # if the value doesn't match the pattern, create a proper ID + # (using the value_id as the integer part if possible) + id_int = int(value_id) if value_id.isdecimal() else current_count + 1 + newname = validate_id(cls, id_int) + # store the converted ID so we can use it elsewhere + CONVERTED_IDS[(id_name, value)] = newname + + # warn the user + msg = f"Casting invalid {id_name}ID {value!r} to {newname!r}" + warnings.warn(msg, stacklevel=2) + return newname + elif not isinstance(value, int): + raise ValueError(f"Invalid ID value: {value!r}, {type(value)}") + + # update the counter to be at least this value + ID_COUNTER[id_name] = max(current_count, value) + return f"{id_name}:{value}" diff --git a/src/ome_types/_mixins/_instrument.py b/src/ome_types/_mixins/_instrument.py new file mode 100644 index 00000000..019aa4ad --- /dev/null +++ b/src/ome_types/_mixins/_instrument.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING, List, Union, cast + +if TYPE_CHECKING: + from ome_types._autogenerated.ome_2016_06 import ( + Arc, + Filament, + GenericExcitationSource, + Instrument, + 
Laser, + LightEmittingDiode, + ) + + LightSource = Union[ + GenericExcitationSource, LightEmittingDiode, Filament, Arc, Laser + ] + + +class InstrumentMixin: + @property + def light_source_group(self) -> List["LightSource"]: + slf = cast("Instrument", self) + return [ + *slf.arcs, + *slf.filaments, + *slf.generic_excitation_sources, + *slf.lasers, + *slf.light_emitting_diodes, + ] diff --git a/src/ome_types/_mixins/_ome.py b/src/ome_types/_mixins/_ome.py new file mode 100644 index 00000000..8e538ce0 --- /dev/null +++ b/src/ome_types/_mixins/_ome.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import warnings +import weakref +from typing import TYPE_CHECKING, Any, cast + +from ome_types._mixins._base_type import OMEType +from ome_types._mixins._ids import CONVERTED_IDS + +if TYPE_CHECKING: + from pathlib import Path + + from ome_types._autogenerated.ome_2016_06 import OME, Reference + + +class OMEMixin: + def __init__(self, **data: Any) -> None: + # Clear the cache of converted IDs, so that they are unique to each OME instance + CONVERTED_IDS.clear() + super().__init__(**data) + self._link_refs() + + def _link_refs(self) -> None: + ids = collect_ids(self) + for ref in collect_references(self): + # all reference subclasses do actually have an 'id' field + # but it's not declared in the base class + if ref.id in ids: + ref._ref = weakref.ref(ids[ref.id]) + else: + warnings.warn(f"Reference to unknown ID: {ref.id}", stacklevel=2) + + def __setstate__(self, state: dict[str, Any]) -> None: + """Support unpickle of our weakref references.""" + super().__setstate__(state) # type: ignore + self._link_refs() + + @classmethod + def from_xml(cls, xml: Path | str) -> OME: + from ome_types._conversion import from_xml + + return from_xml(xml) + + @classmethod + def from_tiff(cls, path: Path | str) -> OME: + from ome_types._conversion import from_tiff + + return from_tiff(path) + + def to_xml(self, **kwargs: Any) -> str: + from ome_types._conversion import to_xml + + 
return to_xml(cast("OME", self), **kwargs) + + +def collect_ids(value: Any) -> dict[str, OMEType]: + """Return a map of all model objects contained in value, keyed by id. + + Recursively walks all dataclass fields and iterates over lists. The base + case is when value is neither a dataclass nor a list. + """ + from ome_types.model import Reference + + ids: dict[str, OMEType] = {} + if isinstance(value, list): + for v in value: + ids.update(collect_ids(v)) + elif isinstance(value, OMEType): + for f in value.__fields__: + if f == "id" and not isinstance(value, Reference): + # We don't need to recurse on the id string, so just record it + # and move on. + ids[value.id] = value + else: + ids.update(collect_ids(getattr(value, f))) + # Do nothing for uninteresting types. + return ids + + +def collect_references(value: Any) -> list[Reference]: + """Return a list of all References contained in value. + + Recursively walks all dataclass fields and iterates over lists. The base + case is when value is either a Reference object, or an uninteresting type + that we don't need to inspect further. 
+ + """ + from ome_types.model import Reference + + references: list[Reference] = [] + if isinstance(value, Reference): + references.append(value) + elif isinstance(value, list): + for v in value: + references.extend(collect_references(v)) + elif isinstance(value, OMEType): + for f in value.__fields__: + references.extend(collect_references(getattr(value, f))) + # Do nothing for uninteresting types + return references diff --git a/src/ome_types/_mixins/_pixels.py b/src/ome_types/_mixins/_pixels.py new file mode 100644 index 00000000..ff8892b0 --- /dev/null +++ b/src/ome_types/_mixins/_pixels.py @@ -0,0 +1,19 @@ +from pydantic import root_validator + +from ome_types._mixins._base_type import OMEType + + +class PixelsMixin(OMEType): + @root_validator(pre=True) + def _validate_root(cls, values: dict) -> dict: + if "metadata_only" in values: + if isinstance(values["metadata_only"], bool): + if not values["metadata_only"]: + values.pop("metadata_only") + else: + # type ignore in case the autogeneration hasn't been built + from ome_types.model import MetadataOnly # type: ignore + + values["metadata_only"] = MetadataOnly() + + return values diff --git a/src/ome_types/_mixins/_reference.py b/src/ome_types/_mixins/_reference.py new file mode 100644 index 00000000..97ee19eb --- /dev/null +++ b/src/ome_types/_mixins/_reference.py @@ -0,0 +1,20 @@ +import weakref +from typing import Any, Dict, Optional, Union + +from ome_types._mixins._base_type import OMEType + + +class ReferenceMixin(OMEType): + _ref: Optional[weakref.ReferenceType] = None + + @property + def ref(self) -> Union[OMEType, None]: + if self._ref is None: + raise ValueError("references not yet resolved on root OME object") + return self._ref() + + def __getstate__(self: Any) -> Dict[str, Any]: + """Support pickle of our weakref references.""" + state = super().__getstate__() + state["__private_attribute_values__"].pop("_ref", None) + return state diff --git a/src/ome_types/_mixins/_util.py 
b/src/ome_types/_mixins/_util.py new file mode 100644 index 00000000..94082690 --- /dev/null +++ b/src/ome_types/_mixins/_util.py @@ -0,0 +1,6 @@ +import uuid + + +def new_uuid() -> str: + """Generate a new UUID.""" + return f"urn:uuid:{uuid.uuid4()}" diff --git a/src/ome_types/_napari_plugin.py b/src/ome_types/_napari_plugin.py deleted file mode 100644 index a0bc26cb..00000000 --- a/src/ome_types/_napari_plugin.py +++ /dev/null @@ -1,32 +0,0 @@ -from napari_plugin_engine import napari_hook_implementation - -from .widgets import OMETree - -METADATA_KEY = "ome_types" - - -@napari_hook_implementation -def napari_experimental_provide_dock_widget(): - return OMETree, {"name": "OME Metadata Viewer"} - - -@napari_hook_implementation -def napari_get_reader(path): - """Show OME XML if an ome.xml file is dropped on the viewer.""" - if isinstance(path, str) and path.endswith("ome.xml"): - return view_ome_xml - - -def view_ome_xml(path): - from napari._qt.qt_main_window import _QtMainWindow - - # close your eyes, or look away... - # there is nothing worth looking at here! - window = _QtMainWindow.current() - if not window: - return - viewer = window.qt_viewer.viewer - dw, widget = viewer.window.add_plugin_dock_widget("ome-types") - widget.update(path) - - return [(None,)] # sentinel diff --git a/src/ome_types/_xmlschema.py b/src/ome_types/_xmlschema.py deleted file mode 100644 index 49b7836b..00000000 --- a/src/ome_types/_xmlschema.py +++ /dev/null @@ -1,305 +0,0 @@ -from __future__ import annotations - -import os.path -from collections import defaultdict -from datetime import datetime -from enum import Enum -from functools import lru_cache -from pathlib import Path -from typing import IO, Any, Union -from xml.etree import ElementTree - -import xmlschema -from xmlschema import ElementData, XMLSchemaParseError -from xmlschema.converters import XMLSchemaConverter -from xmlschema.exceptions import XMLSchemaValueError - -from ome_types._base_type import OMEType - -from . 
import util -from ._constants import NS_OME, NS_XSI, OME_2016_06_XSD, SCHEMA_LOC_OME, URI_OME -from .model import ( - OME, - XMLAnnotation, - _camel_to_snake, - _plural_to_singular, - _singular_to_plural, - _snake_to_camel, - simple_types, -) - -__cache__: dict[str, xmlschema.XMLSchema] = {} -_XMLSCHEMA_VERSION: tuple[int, ...] = tuple( - int(v) if v.isnumeric() else v for v in xmlschema.__version__.split(".") -) - -XMLSourceType = Union[str, bytes, Path, IO[str], IO[bytes]] - - -@lru_cache(maxsize=8) -def _build_schema(ns: str, uri: str | None = None) -> xmlschema.XMLSchema: - """Return Schema object for a url. - - For the special case of retrieving the 2016-06 OME Schema, use local file. - """ - if ns == URI_OME: - schema = xmlschema.XMLSchema(OME_2016_06_XSD) - # FIXME Hack to work around xmlschema poor support for keyrefs to - # substitution groups - ls_sgs = schema.maps.substitution_groups[f"{NS_OME}LightSourceGroup"] - ls_id_maps = schema.maps.identities[f"{NS_OME}LightSourceIDKey"] - ls_id_maps.elements = {e: None for e in ls_sgs} - else: - schema = xmlschema.XMLSchema(uri) - return schema - - -def get_schema(source: xmlschema.XMLResource | XMLSourceType) -> xmlschema.XMLSchema: - """Fetch an XMLSchema object given XML source. - - Parameters - ---------- - source : XMLResource or str - can be an :class:`xmlschema.XMLResource` instance, a file-like object, a path - to a file or an URI of a resource or an Element instance or an ElementTree - instance or a string containing the XML data. 
- - Returns - ------- - xmlschema.XMLSchema - An XMLSchema object for the source - """ - if not isinstance(source, xmlschema.XMLResource): - source = xmlschema.XMLResource(source) - - for ns, uri in source.get_locations(): - try: - return _build_schema(ns, uri) - except XMLSchemaParseError: - pass - raise XMLSchemaValueError(f"Could not find a schema for XML resource {source!r}.") - - -def validate(xml: XMLSourceType, schema: xmlschema.XMLSchema | None = None) -> None: - schema = schema or get_schema(xml) - schema.validate(xml) - - -class OMEConverter(XMLSchemaConverter): - def __init__( - self, namespaces: dict[str, Any] | None = None, **kwargs: dict[Any, Any] - ): - self._ome_ns = "" - super().__init__(namespaces, attr_prefix="") - for name, uri in self._namespaces.items(): - if uri == URI_OME: - self._ome_ns = name - - def map_qname(self, qname: str) -> str: - name = super().map_qname(qname) - if name.lower().startswith(self._ome_ns): - name = name[len(self._ome_ns) :].lstrip(":") - return _camel_to_snake.get(name, name) - - def element_decode(self, data, xsd_element, xsd_type=None, level=0): # type: ignore - """Convert a decoded element data to a data structure.""" - result = super().element_decode(data, xsd_element, xsd_type, level) - if isinstance(result, dict) and "$" in result: - result["value"] = result.pop("$") - # FIXME: Work out a better way to deal with concrete extensions of - # abstract types. 
- if xsd_element.local_name == "MetadataOnly": - result = True - elif xsd_element.local_name == "BinData": - if result["length"] == 0 and "value" not in result: - result["value"] = "" - elif xsd_element.local_name == "StructuredAnnotations": - annotations = [] - for _type in ( - "boolean_annotation", - "comment_annotation", - "double_annotation", - "file_annotation", - "list_annotation", - "long_annotation", - "map_annotation", - "tag_annotation", - "term_annotation", - "timestamp_annotation", - "xml_annotation", - ): - if _type in result: - values = result.pop(_type) - for v in values: - v["_type"] = _type - # Normalize empty element to zero-length string. - if "value" in v and v["value"] is None: - v["value"] = "" - annotations.extend(values) - result = annotations - if isinstance(result, dict): - for name in list(result.keys()): - plural = _singular_to_plural.get((xsd_element.local_name, name), None) - if plural: - value = result.pop(name) - if not isinstance(value, list): - raise TypeError("expected list for plural attr") - result[plural] = value - return result - - def element_encode( - self, obj: Any, xsd_element: xmlschema.XsdElement, level: int = 0 - ) -> ElementData: - tag = xsd_element.qualified_name - if not isinstance(obj, OMEType): - if isinstance(obj, datetime): - return ElementData( - tag, obj.isoformat().replace("+00:00", "Z"), None, {} - ) - elif isinstance(obj, ElementTree.Element): - # ElementData can't represent mixed content, so we'll leave this - # element empty and fix it up after encoding is complete. 
- return ElementData(tag, None, None, {}) - elif xsd_element.type.simple_type is not None: - return ElementData(tag, obj, None, {}) - elif xsd_element.local_name == "MetadataOnly": - return ElementData(tag, None, None, {}) - elif xsd_element.local_name in {"Union", "StructuredAnnotations"}: - names = [type(v).__name__ for v in obj] - content = [(f"{NS_OME}{n}", v) for n, v in zip(names, obj)] - return ElementData(tag, None, content, {}) - else: - raise NotImplementedError( - "Encountered a combination of schema element and data type" - " that is not yet supported. Please submit a bug report with" - " the information below:" - f"\n element: {xsd_element}\n data type: {type(obj)}" - ) - text = None - content = [] - attributes = {} - # FIXME Can we simplify this? - tag_index = defaultdict( - lambda: -1, - ((_camel_to_snake[e.local_name], i) for i, e in enumerate(xsd_element)), - ) - _fields = obj.__fields__.values() - for field in sorted( - _fields, - key=lambda f: tag_index[_plural_to_singular.get(f.name, f.name)], - ): - name = field.name - if name.endswith("_"): - continue - default = ( - field.default_factory() if field.default_factory else field.default - ) - value = getattr(obj, name) - if value == default or name == "metadata_only" and not value: - continue - if isinstance(value, simple_types.Color): - value = value.as_int32() - name = _plural_to_singular.get(name, name) - name = _snake_to_camel.get(name, name) - if name in xsd_element.attributes: - if isinstance(value, list): - value = [getattr(i, "value", i) for i in value] - elif isinstance(value, Enum): - value = value.value - elif isinstance(value, datetime): - value = value.isoformat().replace("+00:00", "Z") - attributes[name] = value - elif name == "Value" and xsd_element.local_name in {"BinData", "UUID", "M"}: - text = value - else: - if not isinstance(value, list) or name in { - "Union", - "StructuredAnnotations", - }: - value = [value] - if name == "LightSourceGroup": - names = [type(v).__name__ for v 
in value] - else: - names = [name] * len(value) - content.extend([(f"{NS_OME}{n}", v) for n, v in zip(names, value)]) - return ElementData(tag, text, content, attributes) - - -def xmlschema2dict( - xml: str, - schema: xmlschema.XMLSchema | None = None, - converter: XMLSchemaConverter = OMEConverter, - validate: bool = False, - **kwargs: Any, -) -> dict[str, Any]: - if isinstance(xml, bytes): - xml = xml.decode("utf-8") - - schema = schema or get_schema(xml) - - if _XMLSCHEMA_VERSION >= (2,): - kwargs["validation"] = "strict" if validate else "lax" - - result = xmlschema.to_dict(xml, schema=schema, converter=converter, **kwargs) - - if _XMLSCHEMA_VERSION >= (2,) and not validate: - result, _ = result - # xmlschema doesn't provide usable access to mixed XML content, so we'll - # fill the XMLAnnotation value attributes ourselves by re-parsing the XML - # with ElementTree and using the Element objects as the values. - tree = None - for annotation in result.get("structured_annotations", []): - if annotation["_type"] == "xml_annotation": - if tree is None: - from io import StringIO - - # determine if we're dealing with a raw XML string or a filepath - # very long XML strings will raise ValueError on Windows. - try: - _xml = xml if os.path.exists(xml) else StringIO(xml) - except ValueError: - _xml = StringIO(xml) - - tree = ElementTree.parse(_xml) # type: ignore # noqa: S314 - aid = annotation["id"] - elt = tree.find(f".//{NS_OME}XMLAnnotation[@ID='{aid}']/{NS_OME}Value") - annotation["value"] = elt - return result - - -def to_xml_element(ome: OME) -> ElementTree.Element: - schema = _build_schema(URI_OME) - root = schema.encode( - ome, path=f"/{NS_OME}OME", converter=OMEConverter, use_defaults=False - ) - # Patch up the XML element tree with Element objects from XMLAnnotations to - # work around xmlschema's lack of support for mixed content. 
- for oid, obj in util.collect_ids(ome).items(): - if isinstance(obj, XMLAnnotation): - elt = root.find(f".//{NS_OME}XMLAnnotation[@ID='{oid}']/{NS_OME}Value") - elt.extend(list(obj.value)) - root.attrib[f"{NS_XSI}schemaLocation"] = SCHEMA_LOC_OME - return root - - -def to_xml(ome: OME, **kwargs: Any) -> str: - """ - Dump an OME object to string. - - Parameters - ---------- - ome: OME - OME object to dump. - **kwargs - Extra kwargs to pass to ElementTree.tostring. - - Returns - ------- - ome_string: str - The XML string of the OME object. - """ - root = to_xml_element(ome) - ElementTree.register_namespace("", URI_OME) - kwargs.setdefault("encoding", "unicode") - kwargs.setdefault("method", "xml") - return ElementTree.tostring(root, **kwargs) diff --git a/src/ome_types/model/__init__.py b/src/ome_types/model/__init__.py new file mode 100644 index 00000000..7ebec954 --- /dev/null +++ b/src/ome_types/model/__init__.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import importlib.util +import sys +from importlib.abc import Loader, MetaPathFinder +from pathlib import Path +from typing import TYPE_CHECKING, Sequence + +from ome_types._autogenerated.ome_2016_06 import * # noqa + +# these are here mostly to make mypy happy in pre-commit +# even when the model isn't built +from ome_types._autogenerated.ome_2016_06 import OME as OME +from ome_types._autogenerated.ome_2016_06 import Reference as Reference + +if TYPE_CHECKING: + from importlib.machinery import ModuleSpec + from types import ModuleType + +# --------------------------------------------------------------------- +# Below here is logic to allow importing from ome_types._autogenerated.ome_2016_06 +# from ome_types.model.* (to preserve backwards compatibility) +# i.e. importing ome_types.model.map will import from +# ome_types._autogenerated.ome_2016_06.map ... 
and emit a warning + + +_OME_2016 = Path(__file__).parent.parent / "_autogenerated" / "ome_2016_06" + + +class OME2016Loader(Loader): + def __init__(self, fullname: str) -> None: + submodule = fullname.split(".", 2)[-1] + file_2016 = (_OME_2016 / submodule.replace(".", "/")).with_suffix(".py") + module_2016 = fullname.replace(".model.", "._autogenerated.ome_2016_06.", 1) + if not file_2016.exists(): # pragma: no cover + raise ImportError( + f"Cannot find {submodule!r} in ome_types._autogenerated.ome_2016_06" + ) + + # warnings.warn( + # "Importing submodules from ome_types.model is deprecated. " + # "Please import types directly from ome_types.model instead.", + # stacklevel=2, + # ) + self.fullname = fullname + self.module_2016 = module_2016 + + def create_module(self, spec: ModuleSpec) -> ModuleType | None: + """Just return the 2016 version.""" + # this will already be in sys.modules because of the star import above + return sys.modules[self.module_2016] + + def exec_module(self, module: ModuleType) -> None: + """We never need to exec.""" + pass + + +# add a sys.meta_path hook to allow import of any modules in +# ome_types._autogenerated.ome_2016_06 +class OMEMetaPathFinder(MetaPathFinder): + def find_spec( + self, + fullname: str, + path: Sequence[str] | None, + target: ModuleType | None = None, + ) -> ModuleSpec | None: + """Return a module spec to redirect to ome_types._autogenerated.ome_2016_06.""" + if fullname.startswith("ome_types.model."): + return importlib.util.spec_from_loader(fullname, OME2016Loader(fullname)) + return None + + +sys.meta_path.append(OMEMetaPathFinder()) + +from ome_types.model._converters import register_converters # noqa + +register_converters() +del register_converters diff --git a/src/ome_types/model/_color.py b/src/ome_types/model/_color.py new file mode 100644 index 00000000..414b15a9 --- /dev/null +++ b/src/ome_types/model/_color.py @@ -0,0 +1,36 @@ +from contextlib import suppress +from typing import Tuple, Union + +from 
pydantic import color + +__all__ = ["Color"] + +RGBA = Tuple[int, int, int, float] +ColorType = Union[Tuple[int, int, int], RGBA, str, int] + + +class Color(color.Color): + def __init__(self, val: ColorType = -1) -> None: + with suppress(ValueError, TypeError): + val = self._int2tuple(int(val)) # type: ignore + super().__init__(val) # type: ignore [arg-type] + + @classmethod + def _int2tuple(cls, val: int) -> RGBA: + return (val >> 24 & 255, val >> 16 & 255, val >> 8 & 255, (val & 255) / 255) + + def as_int32(self) -> int: + """Convert to an int32, with alpha in the least significant byte.""" + r, g, b, *a = self.as_rgb_tuple() + v = r << 24 | g << 16 | b << 8 | int((a[0] if a else 1) * 255) + if v < 2**32 // 2: + return v + return v - 2**32 + + def __eq__(self, o: object) -> bool: + if isinstance(o, Color): + return self.as_int32() == o.as_int32() + return NotImplemented # pragma: no cover + + def __int__(self) -> int: + return self.as_int32() diff --git a/src/ome_types/model/_converters.py b/src/ome_types/model/_converters.py new file mode 100644 index 00000000..c157f324 --- /dev/null +++ b/src/ome_types/model/_converters.py @@ -0,0 +1,37 @@ +import datetime +import warnings +from typing import Any + +from xsdata.formats.converter import Converter, converter +from xsdata.models.datatype import XmlDateTime + +from ome_types.model._color import Color + + +class DateTimeConverter(Converter): + def serialize(self, value: datetime.datetime, **kwargs: Any) -> str: + return str(XmlDateTime.from_datetime(value)) + + def deserialize(self, value: Any, **kwargs: Any) -> datetime.datetime: + xmldt = XmlDateTime.from_string(value) + try: + return xmldt.to_datetime() + except ValueError as e: + msg = f"Invalid datetime: {value!r} {e}." 
+ if xmldt.year <= 0: + msg += "(BC dates are not supported)" + warnings.warn(msg, stacklevel=2) + return datetime.datetime(1, 1, 1) + + +class ColorConverter(Converter): + def serialize(self, value: Color, **kwargs: Any) -> str: + return str(value.as_int32()) + + def deserialize(self, value: Any, **kwargs: Any) -> Color: + return Color(value) + + +def register_converters() -> None: + converter.register_converter(Color, ColorConverter()) + converter.register_converter(datetime.datetime, DateTimeConverter()) diff --git a/src/ome_types/model/_shape_union.py b/src/ome_types/model/_shape_union.py new file mode 100644 index 00000000..10bd571b --- /dev/null +++ b/src/ome_types/model/_shape_union.py @@ -0,0 +1,70 @@ +from contextlib import suppress +from typing import Dict, Iterator, List, Type, Union + +from pydantic import Field, ValidationError, validator + +from ome_types._autogenerated.ome_2016_06.ellipse import Ellipse +from ome_types._autogenerated.ome_2016_06.label import Label +from ome_types._autogenerated.ome_2016_06.line import Line +from ome_types._autogenerated.ome_2016_06.mask import Mask +from ome_types._autogenerated.ome_2016_06.point import Point +from ome_types._autogenerated.ome_2016_06.polygon import Polygon +from ome_types._autogenerated.ome_2016_06.polyline import Polyline +from ome_types._autogenerated.ome_2016_06.rectangle import Rectangle +from ome_types._mixins._base_type import OMEType +from ome_types.model._user_sequence import UserSequence + +ShapeType = Union[Rectangle, Mask, Point, Ellipse, Line, Polyline, Polygon, Label] +_KINDS: Dict[str, Type[ShapeType]] = { + "rectangle": Rectangle, + "mask": Mask, + "point": Point, + "ellipse": Ellipse, + "line": Line, + "polyline": Polyline, + "polygon": Polygon, + "label": Label, +} + +_ShapeCls = tuple(_KINDS.values()) + + +class ShapeUnion(OMEType, UserSequence[ShapeType]): # type: ignore[misc] + # NOTE: in reality, this is List[ShapeGroupType]... 
but + # for some reason that messes up xsdata data binding + __root__: List[object] = Field( + default_factory=list, + metadata={ + "type": "Elements", + "choices": tuple( + {"name": kind.title(), "type": cls} for kind, cls in _KINDS.items() + ), + }, + ) + + @validator("__root__", each_item=True) + def _validate_root(cls, v: ShapeType) -> ShapeType: + if isinstance(v, _ShapeCls): + return v + if isinstance(v, dict): + # NOTE: this is here to preserve the v1 behavior of passing a dict like + # {"kind": "label", "x": 0, "y": 0} + # to create a label rather than a point + if "kind" in v: + kind = v.pop("kind").lower() + return _KINDS[kind](**v) + + for cls_ in _ShapeCls: + with suppress(ValidationError): + return cls_(**v) + raise ValueError(f"Invalid shape: {v}") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__root__!r})" + + # overriding BaseModel.__iter__ to behave more like a real Sequence + def __iter__(self) -> Iterator[ShapeType]: # type: ignore[override] + yield from self.__root__ # type: ignore[misc] # see NOTE above + + def __eq__(self, _value: object) -> bool: + return _value == self.__root__ diff --git a/src/ome_types/model/_structured_annotations.py b/src/ome_types/model/_structured_annotations.py new file mode 100644 index 00000000..cbce27ca --- /dev/null +++ b/src/ome_types/model/_structured_annotations.py @@ -0,0 +1,69 @@ +from contextlib import suppress +from typing import Iterator, List + +from pydantic import Field, ValidationError, validator + +from ome_types._autogenerated.ome_2016_06.annotation import Annotation +from ome_types._autogenerated.ome_2016_06.boolean_annotation import BooleanAnnotation +from ome_types._autogenerated.ome_2016_06.comment_annotation import CommentAnnotation +from ome_types._autogenerated.ome_2016_06.double_annotation import DoubleAnnotation +from ome_types._autogenerated.ome_2016_06.file_annotation import FileAnnotation +from ome_types._autogenerated.ome_2016_06.list_annotation import 
ListAnnotation +from ome_types._autogenerated.ome_2016_06.long_annotation import LongAnnotation +from ome_types._autogenerated.ome_2016_06.map_annotation import MapAnnotation +from ome_types._autogenerated.ome_2016_06.tag_annotation import TagAnnotation +from ome_types._autogenerated.ome_2016_06.term_annotation import TermAnnotation +from ome_types._autogenerated.ome_2016_06.timestamp_annotation import ( + TimestampAnnotation, +) +from ome_types._autogenerated.ome_2016_06.xml_annotation import XMLAnnotation +from ome_types._mixins._base_type import OMEType +from ome_types.model._user_sequence import UserSequence + +AnnotationTypes = ( + XMLAnnotation, + FileAnnotation, + ListAnnotation, + LongAnnotation, + DoubleAnnotation, + CommentAnnotation, + BooleanAnnotation, + TimestampAnnotation, + TagAnnotation, + TermAnnotation, + MapAnnotation, +) + + +class StructuredAnnotationList(OMEType, UserSequence[Annotation]): # type: ignore[misc] + # NOTE: in reality, this is List[StructuredAnnotationTypes]... 
but + # for some reason that messes up xsdata data binding + __root__: List[object] = Field( + default_factory=list, + metadata={ + "type": "Elements", + "choices": tuple( + {"name": cls.__name__, "type": cls} for cls in AnnotationTypes + ), + }, + ) + + @validator("__root__", each_item=True) + def _validate_root(cls, v: Annotation) -> Annotation: + if isinstance(v, AnnotationTypes): + return v + if isinstance(v, dict): + for cls_ in AnnotationTypes: + with suppress(ValidationError): + return cls_(**v) + raise ValueError(f"Invalid Annotation: {v} of type {type(v)}") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__root__!r})" + + # overriding BaseModel.__iter__ to behave more like a real Sequence + def __iter__(self) -> Iterator[Annotation]: # type: ignore[override] + yield from self.__root__ # type: ignore[misc] # see NOTE above + + def __eq__(self, _value: object) -> bool: + return _value == self.__root__ diff --git a/src/ome_types/model/_user_sequence.py b/src/ome_types/model/_user_sequence.py new file mode 100644 index 00000000..71153077 --- /dev/null +++ b/src/ome_types/model/_user_sequence.py @@ -0,0 +1,48 @@ +from typing import Iterable, List, MutableSequence, TypeVar, Union, overload + +T = TypeVar("T") + + +class UserSequence(MutableSequence[T]): + """Generric Mutable sequence, that expects the real list at __root__.""" + + __root__: List[object] + + def __repr__(self) -> str: + return repr(self.__root__) + + def __delitem__(self, _idx: Union[int, slice]) -> None: + del self.__root__[_idx] + + @overload + def __getitem__(self, _idx: int) -> T: + ... + + @overload + def __getitem__(self, _idx: slice) -> List[T]: + ... + + def __getitem__(self, _idx: Union[int, slice]) -> Union[T, List[T]]: + return self.__root__[_idx] # type: ignore[return-value] + + def __len__(self) -> int: + return len(self.__root__) + + @overload + def __setitem__(self, _idx: int, _val: T) -> None: + ... 
+ + @overload + def __setitem__(self, _idx: slice, _val: Iterable[T]) -> None: + ... + + def __setitem__(self, _idx: Union[int, slice], _val: Union[T, Iterable[T]]) -> None: + self.__root__[_idx] = _val # type: ignore[index] + + def insert(self, index: int, value: T) -> None: + self.__root__.insert(index, value) + + # for some reason, without overloading this... append() adds things to the + # beginning of the list instead of the end + def append(self, value: T) -> None: + self.__root__.append(value) diff --git a/src/ome_types/model/simple_types.py b/src/ome_types/model/simple_types.py new file mode 100644 index 00000000..7ec4c590 --- /dev/null +++ b/src/ome_types/model/simple_types.py @@ -0,0 +1,124 @@ +"""This module is only here for backwards compatibility. + +Should add a deprecation warning. +""" +from enum import Enum + +from ome_types._autogenerated.ome_2016_06 import ( + Binning, + Marker, + NamingConvention, + PixelType, + UnitsElectricPotential, + UnitsFrequency, + UnitsLength, + UnitsPower, + UnitsPressure, + UnitsTemperature, + UnitsTime, +) +from ome_types._autogenerated.ome_2016_06.shape import FontFamily +from ome_types.model._color import Color + +__all__ = [ + "PixelType", + "Binning", + "FontFamily", + "Hex40", + "LSID", + "Marker", + "Color", + "NamingConvention", + "NonNegativeFloat", + "NonNegativeInt", + "NonNegativeLong", + "PercentFraction", + "PixelType", + "PositiveFloat", + "PositiveInt", + "UnitsAngle", + "UnitsElectricPotential", + "UnitsFrequency", + "UnitsLength", + "UnitsPower", + "UnitsPressure", + "UnitsTemperature", + "UnitsTime", + "UniversallyUniqueIdentifier", + "AnnotationID", + "ChannelID", + "DatasetID", + "DetectorID", + "DichroicID", + "ExperimenterGroupID", + "ExperimenterID", + "ExperimentID", + "FilterID", + "FilterSetID", + "FolderID", + "ImageID", + "InstrumentID", + "LightSourceID", + "MicrobeamManipulationID", + "ModuleID", + "ObjectiveID", + "PixelsID", + "PlateAcquisitionID", + "PlateID", + "ProjectID", + 
"ReagentID", + "ROIID", + "ScreenID", + "ShapeID", + "WellID", + "WellSampleID", +] + + +Hex40 = bytes +NonNegativeFloat = float +NonNegativeInt = int +NonNegativeLong = int +PercentFraction = float +PositiveFloat = float +PositiveInt = int +UniversallyUniqueIdentifier = str + +# IDs + +LSID = str +AnnotationID = str +ChannelID = str +DatasetID = str +DetectorID = str +DichroicID = str +ExperimenterGroupID = str +ExperimenterID = str +ExperimentID = str +FilterID = str +FilterSetID = str +FolderID = str +ImageID = str +InstrumentID = str +LightSourceID = str +MicrobeamManipulationID = str +ModuleID = str +ObjectiveID = str +PixelsID = str +ROIID = str +PlateAcquisitionID = str +PlateID = str +ProjectID = str +ReagentID = str +ScreenID = str +ShapeID = str +WellID = str +WellSampleID = str + + +class UnitsAngle(Enum): + """The units used to represent an angle.""" + + DEGREE = "deg" + GRADIAN = "gon" + RADIAN = "rad" diff --git a/src/ome_types/schema.py b/src/ome_types/schema.py deleted file mode 100644 index 0ff30eb0..00000000 --- a/src/ome_types/schema.py +++ /dev/null @@ -1,11 +0,0 @@ -import warnings - -from ._xmlschema import * # noqa - -warnings.warn( - "Direct import from ome_types.schema is deprecated. " - "Please import convenience functions directly from ome_types. 
" - "This will raise an error in the future.", - FutureWarning, - stacklevel=2, -) diff --git a/src/ome_types/_units.py b/src/ome_types/units.py similarity index 78% rename from src/ome_types/_units.py rename to src/ome_types/units.py index 4d2e18ec..088ca164 100644 --- a/src/ome_types/_units.py +++ b/src/ome_types/units.py @@ -1,6 +1,6 @@ import pint -ureg = pint.UnitRegistry(auto_reduce_dimensions=True) +ureg: pint.UnitRegistry = pint.UnitRegistry(auto_reduce_dimensions=True) ureg.define("reference_frame = [_reference_frame]") ureg.define("@alias grade = gradian") ureg.define("@alias astronomical_unit = ua") diff --git a/src/ome_types/util.py b/src/ome_types/util.py deleted file mode 100644 index 7b418480..00000000 --- a/src/ome_types/util.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -import re -from functools import lru_cache -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from . import model -from ._base_type import OMEType -from .model.reference import Reference - -if TYPE_CHECKING: - from .model.simple_types import LSID - - -def cast_number(qnum: str) -> str | int | float: - """Attempt to cast a number from a string. - - This function attempts to cast a string to a number. It will first try to parse an - int, then a float, and finally returns a string if both fail. - """ - try: - return int(qnum) - except ValueError: - try: - return float(qnum) - except ValueError: - return qnum - - -def collect_references(value: Any) -> list[Reference]: - """Return a list of all References contained in value. - - Recursively walks all dataclass fields and iterates over lists. The base - case is when value is either a Reference object, or an uninteresting type - that we don't need to inspect further. 
- - """ - references: list[Reference] = [] - if isinstance(value, Reference): - references.append(value) - elif isinstance(value, list): - for v in value: - references.extend(collect_references(v)) - elif isinstance(value, OMEType): - for f in value.__fields__: - references.extend(collect_references(getattr(value, f))) - # Do nothing for uninteresting types - return references - - -def collect_ids(value: Any) -> dict[LSID, OMEType]: - """Return a map of all model objects contained in value, keyed by id. - - Recursively walks all dataclass fields and iterates over lists. The base - case is when value is neither a dataclass nor a list. - """ - ids: dict[LSID, OMEType] = {} - if isinstance(value, list): - for v in value: - ids.update(collect_ids(v)) - elif isinstance(value, OMEType): - for f in value.__fields__: - if f == "id" and not isinstance(value, Reference): - # We don't need to recurse on the id string, so just record it - # and move on. - ids[value.id] = value # type: ignore - else: - ids.update(collect_ids(getattr(value, f))) - # Do nothing for uninteresting types. - return ids - - -CAMEL_REGEX = re.compile(r"(? str: - """Return a snake_case version of a camelCase string.""" - return model._camel_to_snake.get(name, CAMEL_REGEX.sub("_", name).lower()) - - -@lru_cache() -def norm_key(key: str) -> str: - """Return a normalized key.""" - return key.split("}")[-1] - - -def _get_plural(key: str, tag: str) -> str: - return model._singular_to_plural.get((norm_key(tag), key), key) - - -def _ensure_xml_bytes(path_or_str: Path | str | bytes) -> bytes: - """Ensure that `path_or_str` is bytes. Read from disk if it's an existing file.""" - if isinstance(path_or_str, Path): - return path_or_str.read_bytes() - if isinstance(path_or_str, str): - # FIXME: deal with magic number 10. 
I think it's to avoid Path.exists - # failure on a full string - if "xml" not in path_or_str[:10] and Path(path_or_str).exists(): - return Path(path_or_str).read_bytes() - else: - return path_or_str.encode() - if isinstance(path_or_str, bytes): - return path_or_str - raise TypeError("path_or_str must be one of [Path, str, bytes]. I") diff --git a/src/ome_types/validation.py b/src/ome_types/validation.py new file mode 100644 index 00000000..2d287b9b --- /dev/null +++ b/src/ome_types/validation.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import io +import os +from contextlib import suppress +from functools import lru_cache +from pathlib import Path +from typing import IO, TYPE_CHECKING, Union + +if TYPE_CHECKING: + import xmlschema + + XMLSourceType = Union[str, bytes, Path, IO[str], IO[bytes]] + +NS_OME = r"{http://www.openmicroscopy.org/Schemas/OME/2016-06}" +OME_2016_06_XSD = str(Path(__file__).parent / "ome-2016-06.xsd") + + +class ValidationError(ValueError): + ... + + +def validate_xml(xml: XMLSourceType, schema: Path | str | None = None) -> None: + """Validate XML against an XML Schema. + + By default, will validate against the OME 2016-06 schema. + """ + with suppress(ImportError): + return validate_xml_with_lxml(xml, schema) + + with suppress(ImportError): # pragma: no cover + return validate_xml_with_xmlschema(xml, schema) + + raise ImportError( # pragma: no cover + "Validation requires either `lxml` or `xmlschema`. " + "Please pip install one of them." 
+ ) from None + + +def validate_xml_with_lxml( + xml: XMLSourceType, schema: Path | str | None = None +) -> None: + """Validate XML against an XML Schema using lxml.""" + from lxml import etree + + tree = etree.parse(schema or OME_2016_06_XSD) # noqa: S320 + xmlschema = etree.XMLSchema(tree) + + if isinstance(xml, (str, bytes)) and not os.path.isfile(xml): + xml = io.BytesIO(xml.encode("utf-8") if isinstance(xml, str) else xml) + doc = etree.parse(xml) # noqa: S320 + else: + doc = etree.parse(xml) # noqa: S320 + + if not xmlschema.validate(doc): + msg = f"Validation of {str(xml)[:20]!r} failed:" + for error in xmlschema.error_log: + msg += f"\n - line {error.line}: {error.message}" + raise ValidationError(msg) + + +def validate_xml_with_xmlschema( + xml: XMLSourceType, schema: Path | str | None = None +) -> None: + """Validate XML against an XML Schema using xmlschema.""" + from xmlschema.exceptions import XMLSchemaException + + xmlschema = _get_XMLSchema(schema or OME_2016_06_XSD) + try: + xmlschema.validate(xml) + except XMLSchemaException as e: + raise ValidationError(str(e)) from None + + +@lru_cache(maxsize=None) +def _get_XMLSchema(schema: Path | str) -> xmlschema.XMLSchema: + import xmlschema + + xml_schema = xmlschema.XMLSchema(schema) + # FIXME Hack to work around xmlschema poor support for keyrefs to + # substitution groups + ls_sgs = xml_schema.maps.substitution_groups[f"{NS_OME}LightSourceGroup"] + ls_id_maps = xml_schema.maps.identities[f"{NS_OME}LightSourceIDKey"] + ls_id_maps.elements = {e: None for e in ls_sgs} + return xml_schema diff --git a/src/ome_types/widgets.py b/src/ome_types/widgets.py deleted file mode 100644 index 94f6ce72..00000000 --- a/src/ome_types/widgets.py +++ /dev/null @@ -1,166 +0,0 @@ -import os -import warnings -from typing import TYPE_CHECKING, Any, Optional, Union - -from .model import OME - -try: - from qtpy.QtCore import QMimeData, Qt - from qtpy.QtWidgets import QTreeWidget, QTreeWidgetItem -except ImportError as e: - 
raise ImportError( - "qtpy and a Qt backend (pyside or pyqt) is required to use the OME widget:\n" - "pip install qtpy pyqt5" - ) from e - - -if TYPE_CHECKING: - import napari - - -class OMETree(QTreeWidget): - """A Widget that can show OME XML.""" - - def __init__( - self, - ome_dict: Optional[dict] = None, - viewer: "napari.viewer.Viewer" = None, - parent=None, - ) -> None: - super().__init__(parent=parent) - self._viewer = viewer - self.setAcceptDrops(True) - self.setDropIndicatorShown(True) - self.setIndentation(15) - - item = self.headerItem() - font = item.font(0) - font.setBold(True) - item.setFont(0, font) - self.clear() - - self._current_path: Optional[str] = None - if ome_dict: - self.update(ome_dict) - - if viewer is not None: - viewer.layers.selection.events.active.connect( - lambda e: self._try_load_layer(e.value) - ) - self._try_load_layer(viewer.layers.selection.active) - - def clear(self) -> None: - """Clear the widget and reset the header text.""" - self.headerItem().setText(0, "drag/drop file...") - super().clear() - - def _try_load_layer(self, layer: "napari.layers.Layer"): - """Handle napari viewer behavior.""" - from ._napari_plugin import METADATA_KEY - - if layer is not None: - path = str(layer.source.path) - - # deprecated... don't do this ... 
it should be a dict - if callable(layer.metadata): - ome_meta = layer.metadata() - elif isinstance(layer.metadata, OME): - ome_meta = layer.metadata - else: - ome_meta = layer.metadata.get(METADATA_KEY) - if callable(ome_meta): - ome_meta = ome_meta() - - ome = None - if isinstance(ome_meta, OME): - ome = ome_meta - elif path.endswith((".tiff", ".tif")) and path != self._current_path: - try: - ome = OME.from_tiff(path) - except Exception: - return - if isinstance(ome, OME): - self._current_path = path - self.update(ome) - self.headerItem().setText(0, os.path.basename(path)) - else: - self._current_path = None - self.clear() - - def update(self, ome: Union[OME, str]) -> None: - """Update the widget with a new OME object or path to an OME XML file.""" - if not ome: - return - if isinstance(ome, OME): - _ome = ome - elif isinstance(ome, str): - if ome == self._current_path: - return - try: - if ome.endswith(".xml"): - _ome = OME.from_xml(ome) - elif ome.lower().endswith((".tif", ".tiff")): - _ome = OME.from_tiff(ome) - else: - warnings.warn(f"Unrecognized file type: {ome}", stacklevel=2) - return - except Exception as e: - warnings.warn( - f"Could not parse OME metadata from {ome}: {e}", stacklevel=2 - ) - return - self.headerItem().setText(0, os.path.basename(ome)) - self._current_path = ome - else: - raise TypeError("must be OME object or string") - self._fill_item(_ome.dict(exclude_unset=True)) - - def _fill_item(self, obj, item: QTreeWidgetItem = None): - if item is None: - self.clear() - item = self.invisibleRootItem() - if isinstance(obj, dict): - for key, val in sorted(obj.items()): - child = QTreeWidgetItem([key]) - item.addChild(child) - self._fill_item(val, child) - elif isinstance(obj, (list, tuple)): - for n, val in enumerate(obj): - text = val.get("id", n) if hasattr(val, "get") else n - child = QTreeWidgetItem([str(text)]) - item.addChild(child) - self._fill_item(val, child) - else: - t = getattr(obj, "value", str(obj)) - item.setText(0, 
f"{item.text(0)}: {t}") - - def dropMimeData( - self, parent: QTreeWidgetItem, index: int, data: QMimeData, _: Any - ) -> bool: - """Handle drag/drop events to load OME XML files.""" - if data.hasUrls(): - for url in data.urls(): - lf = url.toLocalFile() - if lf.endswith((".xml", ".tiff", ".tif")): - self.update(lf) - return True - return False - - def mimeTypes(self) -> list[str]: - """Return the supported mime types for drag/drop events.""" - return ["text/uri-list"] - - def supportedDropActions(self) -> "Qt.DropActions": - """Return the supported drop actions for drag/drop events.""" - return Qt.CopyAction - - -if __name__ == "__main__": - from qtpy.QtWidgets import QApplication - - app = QApplication([]) - - widget = OMETree() - widget.show() - - app.exec() diff --git a/src/xsdata_pydantic_basemodel/__init__.py b/src/xsdata_pydantic_basemodel/__init__.py new file mode 100644 index 00000000..a0b14a34 --- /dev/null +++ b/src/xsdata_pydantic_basemodel/__init__.py @@ -0,0 +1,4 @@ +"""xsdata pydantic plugin using BaseModel.""" +__version__ = "0.1.0" # FIXME when you pull it out of the ome-types repo +__author__ = "Talley Lambert" +__email__ = "talley.lambert@gmail.com" diff --git a/src/xsdata_pydantic_basemodel/bindings.py b/src/xsdata_pydantic_basemodel/bindings.py new file mode 100644 index 00000000..9a27798e --- /dev/null +++ b/src/xsdata_pydantic_basemodel/bindings.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator +from xml.etree.ElementTree import QName + +from xsdata.formats.dataclass import context, parsers, serializers +from xsdata.formats.dataclass.serializers import config +from xsdata.formats.dataclass.serializers.mixins import XmlWriterEvent +from xsdata.models.enums import QNames +from xsdata.utils import collections, namespaces +from xsdata.utils.constants import EMPTY_MAP, return_input + +if TYPE_CHECKING: + from pydantic import 
BaseModel + from xsdata.formats.dataclass.models.elements import XmlMeta + + +class XmlContext(context.XmlContext): + """Pydantic BaseModel ready xml context instance.""" + + def __init__( + self, + element_name_generator: Callable = return_input, + attribute_name_generator: Callable = return_input, + ): + super().__init__( + element_name_generator, attribute_name_generator, "pydantic-basemodel" + ) + + +class SerializerConfig(config.SerializerConfig): + # here to add `ignore_unset_attributes` support to XmlSerializer + __slots__ = ( + *config.SerializerConfig.__slots__, + "ignore_unset_attributes", + "attribute_sort_key", + ) + + def __init__(self, **kwargs: Any) -> None: + self.ignore_unset_attributes = kwargs.pop("ignore_unset_attributes", False) + self.attribute_sort_key = kwargs.pop("attribute_sort_key", None) + super().__init__(**kwargs) + + +@dataclass +class XmlParser(parsers.XmlParser): + context: XmlContext = field(default_factory=XmlContext) + + +@dataclass +class XmlSerializer(serializers.XmlSerializer): + context: XmlContext = field(default_factory=XmlContext) + + # overriding so we can pass the args we want to next_attribute + # and so that we can skip unset values + def write_dataclass( + self, + obj: BaseModel, + namespace: str | None = None, + qname: str | None = None, + nillable: bool = False, + xsi_type: str | None = None, + ) -> Generator: + """ + Produce an events stream from a dataclass. + + Optionally override the qualified name and the xsi properties + type and nil. + """ + meta = self.context.build( + obj.__class__, namespace, globalns=self.config.globalns + ) + qname = qname or meta.qname + nillable = nillable or meta.nillable + namespace, tag = namespaces.split_qname(qname) + + yield XmlWriterEvent.START, qname + + # XXX: reason 1 for overriding. 
+ ignore_unset = getattr(self.config, "ignore_unset_attributes", False) + for key, value in self.next_attribute( + obj, + meta, + nillable, + xsi_type, + self.config.ignore_default_attributes, + ignore_unset, + getattr(self.config, "attribute_sort_key", None), + ): + yield XmlWriterEvent.ATTR, key, value + + for var, value in self.next_value(obj, meta): + # XXX: reason 2 for overriding. + if ignore_unset and var.name not in obj.__fields_set__: + continue + yield from self.write_value(value, var, namespace) + + yield XmlWriterEvent.END, qname + + # overriding so we can implement support for `ignore_unset_attributes` + # and so that we can sort attributes as we want + @classmethod + def next_attribute( + cls, + obj: BaseModel, + meta: XmlMeta, + nillable: bool, + xsi_type: str | None, + ignore_optionals: bool, + ignore_unset: bool = False, + attribute_sort_key: Callable | None = None, + ) -> Iterator[tuple[str, Any]]: + """ + Return the attribute variables with their object values if set and not + empty iterables. 
+ + :param obj: Input object + :param meta: Object metadata + :param nillable: Is model nillable + :param xsi_type: The true xsi:type of the object + :param ignore_optionals: Skip optional attributes with default + value + :return: + """ + + set_fields = obj.__fields_set__ if ignore_unset else set() + vars_ = meta.get_attribute_vars() + if attribute_sort_key is not None: + vars_ = sorted(meta.get_attribute_vars(), key=attribute_sort_key) + + # ^^^ new + + for var in vars_: + if var.is_attribute: + value = getattr(obj, var.name) + if ( + value is None + or (collections.is_array(value) and not value) + or (ignore_optionals and var.is_optional(value)) + or (ignore_unset and var.name not in set_fields) # new + ): + continue + + yield var.qname, cls.encode(value, var) + else: + yield from getattr(obj, var.name, EMPTY_MAP).items() + + if xsi_type: + yield QNames.XSI_TYPE, QName(xsi_type) + + if nillable: + yield QNames.XSI_NIL, "true" + + +@dataclass +class JsonParser(parsers.JsonParser): + context: XmlContext = field(default_factory=XmlContext) + + +@dataclass +class JsonSerializer(serializers.JsonSerializer): + context: XmlContext = field(default_factory=XmlContext) + + +@dataclass +class UserXmlParser(parsers.UserXmlParser): + context: XmlContext = field(default_factory=XmlContext) diff --git a/src/xsdata_pydantic_basemodel/compat.py b/src/xsdata_pydantic_basemodel/compat.py new file mode 100644 index 00000000..d5202b92 --- /dev/null +++ b/src/xsdata_pydantic_basemodel/compat.py @@ -0,0 +1,149 @@ +from dataclasses import MISSING, field +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + cast, +) +from xml.etree.ElementTree import QName + +from pydantic import BaseModel, validators +from pydantic.fields import Field, ModelField, Undefined +from xsdata.formats.dataclass.compat import Dataclasses, class_types +from xsdata.formats.dataclass.models.elements import XmlType +from xsdata.models.datatype import XmlDate, 
XmlDateTime, XmlDuration, XmlPeriod, XmlTime + +T = TypeVar("T", bound=object) + + +class AnyElement(BaseModel): + """Generic model to bind xml document data to wildcard fields. + + :param qname: The element's qualified name + :param text: The element's text content + :param tail: The element's tail content + :param children: The element's list of child elements. + :param attributes: The element's key-value attribute mappings. + """ + + qname: Optional[str] = Field(default=None) + text: Optional[str] = Field(default=None) + tail: Optional[str] = Field(default=None) + children: List[object] = Field( + default_factory=list, metadata={"type": XmlType.WILDCARD} + ) + attributes: Dict[str, str] = Field( + default_factory=dict, metadata={"type": XmlType.ATTRIBUTES} + ) + + class Config: + arbitrary_types_allowed = True + + +class DerivedElement(BaseModel, Generic[T]): + """Generic model wrapper for type substituted elements. + + Example: eg. ... + + :param qname: The element's qualified name + :param value: The wrapped value + :param type: The real xsi:type + """ + + qname: str + value: T + type: Optional[str] = None + + class Config: + arbitrary_types_allowed = True + + +class PydanticBaseModel(Dataclasses): + @property + def any_element(self) -> Type: + return AnyElement + + @property + def derived_element(self) -> Type: + return DerivedElement + + def is_model(self, obj: Any) -> bool: + clazz = obj if isinstance(obj, type) else type(obj) + if issubclass(clazz, BaseModel): + clazz.update_forward_refs() + return True + + return False + + def get_fields(self, obj: Any) -> Tuple[Any, ...]: + _fields = cast("BaseModel", obj).__fields__.values() + return tuple(_pydantic_field_to_dataclass_field(field) for field in _fields) + + +def _pydantic_field_to_dataclass_field(pydantic_field: ModelField) -> Any: + if pydantic_field.default_factory is not None: + default_factory: Any = pydantic_field.default_factory + default = MISSING + else: + default_factory = MISSING + default = ( 
+ MISSING + if pydantic_field.default in (Undefined, Ellipsis) + else pydantic_field.default + ) + + dataclass_field = field( # type: ignore + default=default, + default_factory=default_factory, + # init=True, + # hash=None, + # compare=True, + metadata=pydantic_field.field_info.extra.get("metadata", {}), + # kw_only=MISSING, + ) + dataclass_field.name = pydantic_field.name + dataclass_field.type = pydantic_field.type_ + return dataclass_field + + +class_types.register("pydantic-basemodel", PydanticBaseModel()) + + +def make_validators(tp: Type, factory: Callable) -> List[Callable]: + def validator(value: Any) -> Any: + if isinstance(value, tp): + return value + + if isinstance(value, str): + return factory(value) + + raise ValueError + + return [validator] + + +if hasattr(validators, "_VALIDATORS"): + validators._VALIDATORS.extend( + [ + (XmlDate, make_validators(XmlDate, XmlDate.from_string)), + (XmlDateTime, make_validators(XmlDateTime, XmlDateTime.from_string)), + (XmlTime, make_validators(XmlTime, XmlTime.from_string)), + (XmlDuration, make_validators(XmlDuration, XmlDuration)), + (XmlPeriod, make_validators(XmlPeriod, XmlPeriod)), + (QName, make_validators(QName, QName)), + ] + ) +else: + import warnings + + warnings.warn( + "Could not find pydantic.validators._VALIDATORS." 
+ "xsdata-pydantic-basemodel may be incompatible with your pydantic version.", + stacklevel=2, + ) diff --git a/src/xsdata_pydantic_basemodel/generator.py b/src/xsdata_pydantic_basemodel/generator.py new file mode 100644 index 00000000..2b161778 --- /dev/null +++ b/src/xsdata_pydantic_basemodel/generator.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from xsdata.formats.dataclass.filters import Filters +from xsdata.formats.dataclass.generator import DataclassGenerator +from xsdata.utils.collections import unique_sequence +from xsdata.utils.text import stop_words + +if TYPE_CHECKING: + from xsdata.codegen.models import Attr, Class + from xsdata.models.config import GeneratorConfig, OutputFormat + +stop_words.update(("schema", "validate")) + + +class PydanticBaseGenerator(DataclassGenerator): + """Python pydantic dataclasses code generator.""" + + @classmethod + def init_filters(cls, config: GeneratorConfig) -> Filters: + return PydanticBaseFilters(config) + + +class PydanticBaseFilters(Filters): + @classmethod + def build_import_patterns(cls) -> dict[str, dict]: + patterns = Filters.build_import_patterns() + patterns.update( + { + "pydantic": { + "Field": [" = Field("], + "BaseModel": ["BaseModel"], + } + } + ) + return {key: patterns[key] for key in sorted(patterns)} + + @classmethod + def build_class_annotation(cls, fmt: OutputFormat) -> str: + # remove the @dataclass decorator + return "" + + def field_definition( + self, attr: Attr, ns_map: dict, parent_namespace: str | None, parents: list[str] + ) -> str: + defn = super().field_definition(attr, ns_map, parent_namespace, parents) + return defn.replace("field(", "Field(") + + def format_arguments(self, kwargs: dict, indent: int = 0) -> str: + # called by field_definition + self.move_metadata_to_pydantic_field(kwargs) + return super().format_arguments(kwargs, indent) + + def class_bases(self, obj: Class, class_name: str) -> list[str]: + # add BaseModel to the class 
bases + # FIXME ... need to dedupe superclasses + bases = super().class_bases(obj, class_name) + return unique_sequence([*bases, "BaseModel"]) + + def move_metadata_to_pydantic_field(self, kwargs: dict, pop: bool = False) -> None: + """Move metadata from the metadata dict to the pydantic Field kwargs.""" + # XXX: can we pop them? or does xsdata need them in the metdata dict as well? + if "metadata" not in kwargs: # pragma: no cover + return + + metadata: dict = kwargs["metadata"] + getitem = metadata.pop if pop else metadata.get + for from_, to_ in [ + ("min_inclusive", "ge"), + ("min_exclusive", "gt"), + ("max_inclusive", "le"), + ("max_exclusive", "lt"), + ("min_occurs", "min_items"), + ("max_occurs", "max_items"), + ("pattern", "regex"), + ("min_length", "min_length"), + ("max_length", "max_length"), + ]: + if from_ in metadata: + kwargs[to_] = getitem(from_) diff --git a/src/xsdata_pydantic_basemodel/hooks/__init__.py b/src/xsdata_pydantic_basemodel/hooks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/xsdata_pydantic_basemodel/hooks/class_type.py b/src/xsdata_pydantic_basemodel/hooks/class_type.py new file mode 100644 index 00000000..bd33c205 --- /dev/null +++ b/src/xsdata_pydantic_basemodel/hooks/class_type.py @@ -0,0 +1,5 @@ +from xsdata.formats.dataclass.compat import class_types + +from xsdata_pydantic_basemodel.compat import PydanticBaseModel + +class_types.register("pydantic-basemodel", PydanticBaseModel()) diff --git a/src/xsdata_pydantic_basemodel/hooks/cli.py b/src/xsdata_pydantic_basemodel/hooks/cli.py new file mode 100644 index 00000000..a962e295 --- /dev/null +++ b/src/xsdata_pydantic_basemodel/hooks/cli.py @@ -0,0 +1,5 @@ +from xsdata.codegen.writer import CodeWriter + +from xsdata_pydantic_basemodel.generator import PydanticBaseGenerator + +CodeWriter.register_generator("pydantic-basemodel", PydanticBaseGenerator) diff --git a/src/xsdata_pydantic_basemodel/py.typed b/src/xsdata_pydantic_basemodel/py.typed new file mode 
100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..74f2dc9e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from typing import TYPE_CHECKING, Any, DefaultDict + +import pytest + +from ome_types import model +from ome_types._mixins import _base_type + +if TYPE_CHECKING: + from collections import defaultdict + +DATA = Path(__file__).parent / "data" +ALL_XML = set(DATA.glob("*.ome.xml")) +INVALID = {DATA / "invalid_xml_annotation.ome.xml", DATA / "bad.ome.xml"} + + +def _true_stem(p: Path) -> str: + return p.name.partition(".")[0] + + +@pytest.fixture(params=sorted(ALL_XML), ids=_true_stem) +def any_xml(request: pytest.FixtureRequest) -> Path: + return request.param + + +@pytest.fixture(params=sorted(ALL_XML - INVALID), ids=_true_stem) +def valid_xml(request: pytest.FixtureRequest) -> Path: + return request.param + + +@pytest.fixture(params=sorted(INVALID), ids=_true_stem) +def invalid_xml(request: pytest.FixtureRequest) -> Path: + return request.param + + +@pytest.fixture +def single_xml() -> Path: + return DATA / "example.ome.xml" + + +def pytest_addoption(parser: pytest.Parser) -> None: + parser.addoption( + "--ome-watch", + action="store_true", + default=False, + help="Monitor instantiation of all OME objects.", + ) + + +USED_CLASS_KWARGS: defaultdict[str, set[str]] = DefaultDict(set) + + +@pytest.fixture(autouse=True, scope="session") +def _monitor(request: pytest.FixtureRequest) -> None: + """Monitor instantiation of all OME objects. + + This is another form of coverage... 
to see what fields are actually being tested + by our data + """ + if not request.config.getoption("--ome-watch"): + return + + original = _base_type.OMEType.__init__ + + def patched(__pydantic_self__, **kwargs: Any) -> None: + original(__pydantic_self__, **kwargs) + USED_CLASS_KWARGS[__pydantic_self__.__class__.__name__].update(kwargs) + + _base_type.OMEType.__init__ = patched + + +def print_unused_kwargs() -> None: + """Print a table of unused kwargs for each class.""" + from rich.console import Console + from rich.table import Table + + table = Table(title="Class usage summary", border_style="cyan", expand=True) + table.add_column("Class", style="cyan") + table.add_column("Unused fields", style="red", max_width=50) + table.add_column("Percent Used", style="Green") + + rows: list[tuple[str, str, float]] = [] + total_fields = 0 + total_used = 0 + for cls_name in dir(model): + # loop over all classes in the model + cls = getattr(model, cls_name, None) + if not isinstance(cls, type): + continue + + # get a list of all fields (ignore refs) + all_fields: set[str] = set(getattr(cls, "__fields__", {})) + all_fields = {f for f in all_fields if not f.rstrip("s").endswith("ref")} + for base_cls in cls.__bases__: + all_fields -= set(getattr(base_cls, "__fields__", {})) + if not all_fields: + continue + + # determine how many have been used + used = USED_CLASS_KWARGS[cls_name] + unused_fields = all_fields - used + total_fields += len(all_fields) + total_used += len(all_fields) - len(unused_fields) + percent_used = 100 * (1 - len(unused_fields) / len(all_fields)) + if percent_used < 100: + rows.append((cls_name, ", ".join(unused_fields), percent_used)) + + # sort by percent used + for row in sorted(rows, key=lambda r: r[2]): + name, unused, percent = row + table.add_row(name, unused, f"{percent:.1f}%") + # add total + table.add_row( + "[bold yellow]TOTAL[/bold yellow]", + "", + f"{100 * total_used / total_fields:.1f}%", + ) + # print + Console().print(table) + + 
+@pytest.hookimpl(trylast=True) +def pytest_configure(config: pytest.Config) -> None: + if not config.getoption("--ome-watch"): + return + + from _pytest.terminal import TerminalReporter + + class CustomTerminalReporter(TerminalReporter): # type: ignore + def summary_stats(self) -> None: + super().summary_stats() + print_unused_kwargs() + + # Get the standard terminal reporter plugin and replace it with our + standard_reporter = config.pluginmanager.getplugin("terminalreporter") + custom_reporter = CustomTerminalReporter(config, sys.stdout) + config.pluginmanager.unregister(standard_reporter) + config.pluginmanager.register(custom_reporter, "terminalreporter") diff --git a/tests/data/old_model.json b/tests/data/old_model.json new file mode 100644 index 00000000..a0c0f0f9 --- /dev/null +++ b/tests/data/old_model.json @@ -0,0 +1,370 @@ +{ + "binary_only": { + "metadata_file": "str", + "uuid": "UniversallyUniqueIdentifier" + }, + "creator": "str", + "datasets": { + "id": "DatasetID", + "annotation_ref": { "id": "AnnotationID" }, + "description": "str", + "experimenter_group_ref": { "id": "ExperimenterGroupID" }, + "experimenter_ref": { "id": "ExperimenterID" }, + "image_ref": { "id": "ImageID" }, + "name": "str" + }, + "experimenter_groups": { + "id": "ExperimenterGroupID", + "annotation_ref": { "id": "AnnotationID" }, + "description": "str", + "experimenter_ref": { "id": "ExperimenterID" }, + "leader": { "id": "ExperimenterID" }, + "name": "str" + }, + "experimenters": { + "id": "ExperimenterID", + "annotation_ref": { "id": "AnnotationID" }, + "email": "str", + "first_name": "str", + "institution": "str", + "last_name": "str", + "middle_name": "str", + "user_name": "str" + }, + "experiments": { + "id": "ExperimentID", + "description": "str", + "experimenter_ref": { "id": "ExperimenterID" }, + "microbeam_manipulations": { + "experimenter_ref": { "id": "ExperimenterID" }, + "id": "MicrobeamManipulationID", + "roi_ref": { "id": "ROIID" }, + "description": "str", + 
"light_source_settings": { + "id": "LightSourceID", + "attenuation": "PercentFraction", + "wavelength": "PositiveFloat", + "wavelength_unit": "UnitsLength" + }, + "type": "List" + }, + "type": "List" + }, + "folders": { + "id": "FolderID", + "annotation_ref": { "id": "AnnotationID" }, + "description": "str", + "folder_ref": { "id": "FolderID" }, + "image_ref": { "id": "ImageID" }, + "name": "str", + "roi_ref": { "id": "ROIID" } + }, + "images": { + "id": "ImageID", + "pixels": { + "dimension_order": "DimensionOrder", + "id": "PixelsID", + "size_c": "PositiveInt", + "size_t": "PositiveInt", + "size_x": "PositiveInt", + "size_y": "PositiveInt", + "size_z": "PositiveInt", + "type": "PixelType", + "big_endian": "bool", + "bin_data": { + "value": "str", + "big_endian": "bool", + "length": "int", + "compression": "Compression" + }, + "channels": { + "id": "ChannelID", + "acquisition_mode": "AcquisitionMode", + "annotation_ref": { "id": "AnnotationID" }, + "color": "Color", + "contrast_method": "ContrastMethod", + "detector_settings": { + "id": "DetectorID", + "binning": "Binning", + "gain": "float", + "integration": "PositiveInt", + "offset": "float", + "read_out_rate": "float", + "read_out_rate_unit": "UnitsFrequency", + "voltage": "float", + "voltage_unit": "UnitsElectricPotential", + "zoom": "float" + }, + "emission_wavelength": "PositiveFloat", + "emission_wavelength_unit": "UnitsLength", + "excitation_wavelength": "PositiveFloat", + "excitation_wavelength_unit": "UnitsLength", + "filter_set_ref": { "id": "FilterSetID" }, + "fluor": "str", + "illumination_type": "IlluminationType", + "light_path": { + "annotation_ref": { "id": "AnnotationID" }, + "dichroic_ref": { "id": "DichroicID" }, + "emission_filter_ref": { "id": "FilterID" }, + "excitation_filter_ref": { "id": "FilterID" } + }, + "light_source_settings": { + "id": "LightSourceID", + "attenuation": "PercentFraction", + "wavelength": "PositiveFloat", + "wavelength_unit": "UnitsLength" + }, + "name": "str", + 
"nd_filter": "float", + "pinhole_size": "float", + "pinhole_size_unit": "UnitsLength", + "pockel_cell_setting": "int", + "samples_per_pixel": "PositiveInt" + }, + "interleaved": "bool", + "metadata_only": "bool", + "physical_size_x": "PositiveFloat", + "physical_size_x_unit": "UnitsLength", + "physical_size_y": "PositiveFloat", + "physical_size_y_unit": "UnitsLength", + "physical_size_z": "PositiveFloat", + "physical_size_z_unit": "UnitsLength", + "planes": { + "the_c": "NonNegativeInt", + "the_t": "NonNegativeInt", + "the_z": "NonNegativeInt", + "annotation_ref": { "id": "AnnotationID" }, + "delta_t": "float", + "delta_t_unit": "UnitsTime", + "exposure_time": "float", + "exposure_time_unit": "UnitsTime", + "hash_sha1": "Hex40", + "position_x": "float", + "position_x_unit": "UnitsLength", + "position_y": "float", + "position_y_unit": "UnitsLength", + "position_z": "float", + "position_z_unit": "UnitsLength" + }, + "significant_bits": "PositiveInt", + "tiff_data_blocks": { + "first_c": "NonNegativeInt", + "first_t": "NonNegativeInt", + "first_z": "NonNegativeInt", + "ifd": "NonNegativeInt", + "plane_count": "NonNegativeInt", + "uuid": { "file_name": "str", "value": "UniversallyUniqueIdentifier" } + }, + "time_increment": "float", + "time_increment_unit": "UnitsTime" + }, + "acquisition_date": "datetime", + "annotation_ref": { "id": "AnnotationID" }, + "description": "str", + "experiment_ref": { "id": "ExperimentID" }, + "experimenter_group_ref": { "id": "ExperimenterGroupID" }, + "experimenter_ref": { "id": "ExperimenterID" }, + "imaging_environment": { + "air_pressure": "float", + "air_pressure_unit": "UnitsPressure", + "co2_percent": "PercentFraction", + "humidity": "PercentFraction", + "map": { "m": { "value": "str", "k": "str" } }, + "temperature": "float", + "temperature_unit": "UnitsTemperature" + }, + "instrument_ref": { "id": "InstrumentID" }, + "microbeam_manipulation_ref": { "id": "MicrobeamManipulationID" }, + "name": "str", + "objective_settings": { + 
"id": "ObjectiveID", + "correction_collar": "float", + "medium": "Medium", + "refractive_index": "float" + }, + "roi_ref": { "id": "ROIID" }, + "stage_label": { + "name": "str", + "x": "float", + "x_unit": "UnitsLength", + "y": "float", + "y_unit": "UnitsLength", + "z": "float", + "z_unit": "UnitsLength" + } + }, + "instruments": { + "id": "InstrumentID", + "annotation_ref": { "id": "AnnotationID" }, + "detectors": { + "lot_number": "str", + "manufacturer": "str", + "model": "str", + "serial_number": "str", + "id": "DetectorID", + "amplification_gain": "float", + "annotation_ref": { "id": "AnnotationID" }, + "gain": "float", + "offset": "float", + "type": "Type", + "voltage": "float", + "voltage_unit": "UnitsElectricPotential", + "zoom": "float" + }, + "dichroics": { + "lot_number": "str", + "manufacturer": "str", + "model": "str", + "serial_number": "str", + "id": "DichroicID", + "annotation_ref": { "id": "AnnotationID" } + }, + "filter_sets": { + "lot_number": "str", + "manufacturer": "str", + "model": "str", + "serial_number": "str", + "id": "FilterSetID", + "dichroic_ref": { "id": "DichroicID" }, + "emission_filter_ref": { "id": "FilterID" }, + "excitation_filter_ref": { "id": "FilterID" } + }, + "filters": { + "lot_number": "str", + "manufacturer": "str", + "model": "str", + "serial_number": "str", + "id": "FilterID", + "annotation_ref": { "id": "AnnotationID" }, + "filter_wheel": "str", + "transmittance_range": { + "cut_in": "PositiveFloat", + "cut_in_tolerance": "NonNegativeFloat", + "cut_in_tolerance_unit": "UnitsLength", + "cut_in_unit": "UnitsLength", + "cut_out": "PositiveFloat", + "cut_out_tolerance": "NonNegativeFloat", + "cut_out_tolerance_unit": "UnitsLength", + "cut_out_unit": "UnitsLength", + "transmittance": "PercentFraction" + }, + "type": "Type" + }, + "light_source_group": "List[Union]", + "microscope": { + "lot_number": "str", + "manufacturer": "str", + "model": "str", + "serial_number": "str", + "type": "Type" + }, + "objectives": { + 
"lot_number": "str", + "manufacturer": "str", + "model": "str", + "serial_number": "str", + "id": "ObjectiveID", + "annotation_ref": { "id": "AnnotationID" }, + "calibrated_magnification": "float", + "correction": "Correction", + "immersion": "Immersion", + "iris": "bool", + "lens_na": "float", + "nominal_magnification": "float", + "working_distance": "float", + "working_distance_unit": "UnitsLength" + } + }, + "plates": { + "id": "PlateID", + "annotation_ref": { "id": "AnnotationID" }, + "column_naming_convention": "NamingConvention", + "columns": "PositiveInt", + "description": "str", + "external_identifier": "str", + "field_index": "NonNegativeInt", + "name": "str", + "plate_acquisitions": { + "id": "PlateAcquisitionID", + "annotation_ref": { "id": "AnnotationID" }, + "description": "str", + "end_time": "datetime", + "maximum_field_count": "PositiveInt", + "name": "str", + "start_time": "datetime", + "well_sample_ref": { "id": "WellSampleID" } + }, + "row_naming_convention": "NamingConvention", + "rows": "PositiveInt", + "status": "str", + "well_origin_x": "float", + "well_origin_x_unit": "UnitsLength", + "well_origin_y": "float", + "well_origin_y_unit": "UnitsLength", + "wells": { + "column": "NonNegativeInt", + "id": "WellID", + "row": "NonNegativeInt", + "annotation_ref": { "id": "AnnotationID" }, + "color": "Color", + "external_description": "str", + "external_identifier": "str", + "reagent_ref": { "id": "ReagentID" }, + "type": "str", + "well_samples": { + "id": "WellSampleID", + "index": "NonNegativeInt", + "image_ref": { "id": "ImageID" }, + "position_x": "float", + "position_x_unit": "UnitsLength", + "position_y": "float", + "position_y_unit": "UnitsLength", + "timepoint": "datetime" + } + } + }, + "projects": { + "id": "ProjectID", + "annotation_ref": { "id": "AnnotationID" }, + "dataset_ref": { "id": "DatasetID" }, + "description": "str", + "experimenter_group_ref": { "id": "ExperimenterGroupID" }, + "experimenter_ref": { "id": "ExperimenterID" }, + 
"name": "str" + }, + "rights": { "rights_held": "str", "rights_holder": "str" }, + "rois": { + "id": "ROIID", + "annotation_ref": { "id": "AnnotationID" }, + "description": "str", + "name": "str", + "union": "List[Union]" + }, + "screens": { + "id": "ScreenID", + "annotation_ref": { "id": "AnnotationID" }, + "description": "str", + "name": "str", + "plate_ref": { "id": "PlateID" }, + "protocol_description": "str", + "protocol_identifier": "str", + "reagent_set_description": "str", + "reagent_set_identifier": "str", + "reagents": { + "id": "ReagentID", + "annotation_ref": { "id": "AnnotationID" }, + "description": "str", + "name": "str", + "reagent_identifier": "str" + }, + "type": "str" + }, + "structured_annotations": { + "id": "AnnotationID", + "annotation_ref": { "id": "AnnotationID" }, + "annotator": "ExperimenterID", + "description": "str", + "namespace": "str" + }, + "uuid": "UniversallyUniqueIdentifier" +} diff --git a/tests/data/timestampannotation.ome.xml b/tests/data/timestampannotation.ome.xml index a30717f7..0f38b67e 100644 --- a/tests/data/timestampannotation.ome.xml +++ b/tests/data/timestampannotation.ome.xml @@ -174,6 +174,8 @@ Namespace="sample.openmicroscopy.org/time/romefire"> 0066-07-18T00:00:00 + diff --git a/tests/data/transfer.ome.xml b/tests/data/transfer.ome.xml new file mode 100644 index 00000000..9e044065 --- /dev/null +++ b/tests/data/transfer.ome.xml @@ -0,0 +1,70 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + this is a test to see if the kv pairs in omero have any length limits. I don't think they do, but I will write something relatively long just so I can double-check whether that is the case or not. 
+ + + + simple_tag + + + root_0/2022-01/14/18-30-55.264/combined_result.tiff + + + root_0/2022-01/14/18-30-55.264/combined_result.tiff + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/test_autogen.py b/tests/test_autogen.py index f3c90a57..dc871878 100644 --- a/tests/test_autogen.py +++ b/tests/test_autogen.py @@ -1,20 +1,53 @@ import importlib +import os import sys +from importlib.util import find_spec from pathlib import Path import pytest +from _pytest.monkeypatch import MonkeyPatch -ome_autogen = pytest.importorskip("ome_autogen") +from ome_types._mixins import _base_type -def test_autogen(tmp_path_factory): +@pytest.fixture +def imports_autogen(monkeypatch: MonkeyPatch) -> None: + """This fixture adds the src folder to sys.path so we can import ome_autogen. + + The goal here is to be able to run `pip install .[test,dev]` on CI, NOT in editable + mode, but still be able to test autogen without requiring it to be included in the + wheel. + """ + if not find_spec("ome_autogen"): + SRC = Path(__file__).parent.parent / "src" + assert (SRC / "ome_autogen").is_dir() + monkeypatch.syspath_prepend(str(SRC)) + + +@pytest.mark.skipif(not os.getenv("CI"), reason="slow") +@pytest.mark.usefixtures("imports_autogen") +def test_autogen(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: """Test that autogen works without raising an exception. - This does *not* actually test the resulting model. + and that mypy has no issues with it. 
""" - target_dir = tmp_path_factory.mktemp("_ome_types_test_model") - xsd = Path(__file__).parent.parent / "src" / "ome_types" / "ome-2016-06.xsd" - ome_autogen.convert_schema(url=xsd, target_dir=target_dir) - sys.path.insert(0, str(target_dir.parent)) - assert importlib.import_module(target_dir.name) - sys.path.pop(0) + import ome_autogen.main + + ome_autogen.main.build_model(output_dir=tmp_path, do_formatting=True, do_mypy=True) + + monkeypatch.delitem(sys.modules, "ome_types") + monkeypatch.delitem(sys.modules, "ome_types._autogenerated") + monkeypatch.delitem(sys.modules, "ome_types._autogenerated.ome_2016_06") + monkeypatch.syspath_prepend(str(tmp_path)) + mod = importlib.import_module("ome_types._autogenerated.ome_2016_06") + assert mod.__file__ + assert mod.__file__.startswith(str(tmp_path)) + assert mod.Channel(color="blue") + + +@pytest.mark.usefixtures("imports_autogen") +def test_autosequence_name() -> None: + """These should match, but shouldn't be imported from each other.""" + from ome_autogen import _generator + + assert _generator.AUTO_SEQUENCE == _base_type.AUTO_SEQUENCE diff --git a/tests/test_ids.py b/tests/test_ids.py new file mode 100644 index 00000000..99e5e59b --- /dev/null +++ b/tests/test_ids.py @@ -0,0 +1,34 @@ +import pytest + +from ome_types import from_xml +from ome_types.model import Line, Rectangle + + +def test_shape_ids() -> None: + rect = Rectangle(x=0, y=0, width=1, height=1) + line = Line(x1=0, y1=0, x2=1, y2=1) + assert rect.id == "Shape:0" + assert line.id == "Shape:1" + + +def test_id_conversion() -> None: + """When converting ids, we should still be preserving references.""" + XML_WITH_BAD_REFS = """ + + + + + + + + + + """ + with pytest.warns(match="Casting invalid InstrumentID"): + ome = from_xml(XML_WITH_BAD_REFS) + + assert ome.instruments[0].id == "Instrument:0" + assert ome.images[0].instrument_ref is not None + assert ome.images[0].instrument_ref.id == "Instrument:0" + assert ome.images[0].instrument_ref.ref is 
ome.instruments[0] diff --git a/tests/test_invalid_schema.py b/tests/test_invalid_schema.py index ebead882..5887efc8 100644 --- a/tests/test_invalid_schema.py +++ b/tests/test_invalid_schema.py @@ -4,7 +4,8 @@ from ome_types import from_xml -DATA = Path(__file__).parent / "data" +TESTS = Path(__file__).parent +DATA = TESTS / "data" def test_bad_xml_annotation() -> None: @@ -12,4 +13,5 @@ def test_bad_xml_annotation() -> None: with pytest.warns(match="Casting invalid AnnotationID"): ome = from_xml(DATA / "invalid_xml_annotation.ome.xml") assert len(ome.images) == 1 + assert ome.structured_annotations assert ome.structured_annotations[0].id == "Annotation:0" diff --git a/tests/test_model.py b/tests/test_model.py index 5b5e1aec..a5d355cc 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,230 +1,304 @@ -import pickle +from __future__ import annotations + +import datetime import re +import sys +from functools import lru_cache from pathlib import Path -from unittest import mock +from typing import TYPE_CHECKING, cast from xml.dom import minidom -from xml.etree import ElementTree +from xml.etree import ElementTree as ET import pytest from pydantic import ValidationError -from xmlschema.validators.exceptions import XMLSchemaValidationError -import util +import ome_types from ome_types import from_tiff, from_xml, model, to_xml -from ome_types._xmlschema import NS_OME, URI_OME, get_schema, to_xml_element - -ValidationErrors = [ValidationError, XMLSchemaValidationError] -try: - from lxml.etree import XMLSchemaValidateError +from ome_types._conversion import _get_ome_type - ValidationErrors.append(XMLSchemaValidateError) -except ImportError: - pass +if TYPE_CHECKING: + import xmlschema + from _pytest.mark.structures import ParameterSet +TESTS = Path(__file__).parent +DATA = TESTS / "data" -SHOULD_FAIL_READ = { - # Some timestamps have negative years which datetime doesn't support. 
- "timestampannotation", -} -SHOULD_FAIL_VALIDATION = {"invalid_xml_annotation"} -SHOULD_RAISE_READ = {"bad"} +SHOULD_FAIL_VALIDATION = {"invalid_xml_annotation", "bad"} SHOULD_FAIL_ROUNDTRIP = { # Order of elements in StructuredAnnotations and Union are jumbled. "timestampannotation-posix-only", "transformations-downgrade", "invalid_xml_annotation", } -SHOULD_FAIL_ROUNDTRIP_LXML = { - "folders-simple-taxonomy", - "folders-larger-taxonomy", -} + SKIP_ROUNDTRIP = { # These have XMLAnnotations with extra namespaces and mixed content, which # the automated round-trip test code doesn't properly verify yet. So even # though these files do appear to round-trip correctly when checked by eye, # we'll play it safe and skip them until the test is fixed. "spim", - "xmlannotation-body-space", "xmlannotation-multi-value", "xmlannotation-svg", } - -def mark_xfail(fname): - return pytest.param( - fname, - marks=pytest.mark.xfail( - strict=True, reason="Unexpected success. You fixed it!" - ), - ) - - -def mark_skip(fname): - return pytest.param(fname, marks=pytest.mark.skip) +URI_OME = "http://www.openmicroscopy.org/Schemas/OME/2016-06" +SCHEMA_LOCATION = "{http://www.w3.org/2001/XMLSchema-instance}schemaLocation" +NS_OME = "{" + URI_OME + "}" +OME_2016_06_XSD = str(Path(ome_types.__file__).parent / "ome-2016-06.xsd") -def true_stem(p): +def true_stem(p: Path) -> str: return p.name.partition(".")[0] -all_xml = list((Path(__file__).parent / "data").glob("*.ome.xml")) -xml_read = [mark_xfail(f) if true_stem(f) in SHOULD_FAIL_READ else f for f in all_xml] -xml_roundtrip = [] +all_xml = list(DATA.glob("*.ome.xml")) +xml_roundtrip: list[Path | ParameterSet] = [] for f in all_xml: stem = true_stem(f) - if stem in SHOULD_FAIL_READ | SHOULD_RAISE_READ: - continue - elif stem in SHOULD_FAIL_ROUNDTRIP: - f = mark_xfail(f) + if stem in SHOULD_FAIL_ROUNDTRIP: + mrk = pytest.mark.xfail(strict=True, reason="Unexpected success. 
You fixed it!") + f = pytest.param(f, marks=mrk) # type: ignore elif stem in SKIP_ROUNDTRIP: - f = mark_skip(f) + f = pytest.param(f, marks=pytest.mark.skip) # type: ignore xml_roundtrip.append(f) -validate = [True, False] +validate = [False] -parser = ["lxml", "xmlschema"] + +@pytest.mark.parametrize("validate", validate) +def test_from_valid_xml(valid_xml: Path, validate: bool) -> None: + ome = from_xml(valid_xml, validate=validate) + assert ome + assert repr(ome) -@pytest.mark.parametrize("xml", xml_read, ids=true_stem) -@pytest.mark.parametrize("parser", parser) @pytest.mark.parametrize("validate", validate) -def test_from_xml(xml, parser: str, validate: bool, benchmark): - should_raise = SHOULD_RAISE_READ.union(SHOULD_FAIL_VALIDATION if validate else []) - if true_stem(xml) in should_raise: - with pytest.raises(tuple(ValidationErrors)): - assert benchmark(from_xml, xml, parser=parser, validate=validate) +def test_from_invalid_xml(invalid_xml: Path, validate: bool) -> None: + if validate: + with pytest.raises(ValidationError): + from_xml(invalid_xml, validate=validate) else: - assert benchmark(from_xml, xml, parser=parser, validate=validate) + with pytest.warns(): + from_xml(invalid_xml, validate=validate) -@pytest.mark.parametrize("parser", parser) @pytest.mark.parametrize("validate", validate) -def test_from_tiff(benchmark, validate, parser): +def test_from_tiff(validate: bool) -> None: """Test that OME metadata extractions from Tiff headers works.""" - _path = Path(__file__).parent / "data" / "ome.tiff" - ome = benchmark(from_tiff, _path, parser=parser, validate=validate) + _path = DATA / "ome.tiff" + ome = from_tiff(_path, validate=validate) assert len(ome.images) == 1 assert ome.images[0].id == "Image:0" assert ome.images[0].pixels.size_x == 6 assert ome.images[0].pixels.channels[0].samples_per_pixel == 1 -@pytest.mark.parametrize("xml", xml_roundtrip, ids=true_stem) -@pytest.mark.parametrize("parser", parser) -@pytest.mark.parametrize("validate", 
validate) -def test_roundtrip(xml, parser, validate, benchmark): - """Ensure we can losslessly round-trip XML through the model and back.""" - xml = str(xml) - schema = get_schema(xml) - - def canonicalize(xml, strip_empty): - d = schema.decode(xml, use_defaults=True) - # Strip extra whitespace in the schemaLocation value. - d["@xsi:schemaLocation"] = re.sub(r"\s+", " ", d["@xsi:schemaLocation"]) - root = schema.encode(d, path=NS_OME + "OME", use_defaults=True) - # These are the tags that appear in the example files with empty - # content. Since our round-trip will drop empty elements, we'll need to - # strip them from the "original" documents before comparison. - if strip_empty: - for tag in ("Description", "LightPath", "Map"): - for e in root.findall(f".//{NS_OME}{tag}[.='']..."): - e.remove(e.find(f"{NS_OME}{tag}")) - # ET.canonicalize can't handle an empty namespace so we need to - # re-register the OME namespace with an actual name before calling - # tostring. - ElementTree.register_namespace("ome", URI_OME) - xml_out = ElementTree.tostring(root, "unicode") - xml_out = util.canonicalize(xml_out, strip_text=True) - xml_out = minidom.parseString(xml_out).toprettyxml(indent=" ") - return xml_out +def test_roundtrip_inverse(valid_xml: Path, tmp_path: Path) -> None: + """both variants have been touched by the model, here...""" + ome1 = from_xml(valid_xml) - original = canonicalize(xml, True) - ome = from_xml(xml, parser=parser, validate=validate) - rexml = benchmark(to_xml, ome) + # FIXME: + # there is a small difference in the XML output when using xml instead of lxml + # that makes the text of an xml annotation in `xmlannotation-multi-value` be + # 'B\n ' instead of 'B'. + # we should investigate this and fix it, but here we just use indent=0 to avoid it. 
+ xml = to_xml(ome1, indent=0) + out = tmp_path / "test.xml" + out.write_bytes(xml.encode()) + ome2 = from_xml(out) - try: - assert canonicalize(rexml, False) == original - except AssertionError: - # Special xfail catch since two files fail only with xml2dict - if true_stem(Path(xml)) in SHOULD_FAIL_ROUNDTRIP_LXML and parser == "lxml": - pytest.xfail( - f"Expected failure on roundtrip using xml2dict on file: {stem}" - ) - else: - raise + assert ome1 == ome2 -@pytest.mark.parametrize("parser", parser) -@pytest.mark.parametrize("validate", validate) -def test_to_xml_with_kwargs(validate, parser): - """Ensure kwargs are passed to ElementTree""" - ome = from_xml( - Path(__file__).parent / "data" / "example.ome.xml", - parser=parser, - validate=validate, - ) +# @pytest.mark.parametrize("validate", validate) +# def test_to_xml_with_kwargs(validate): +# """Ensure kwargs are passed to ElementTree""" +# ome = from_xml(DATA / "example.ome.xml", validate=validate) - with mock.patch("xml.etree.ElementTree.tostring") as mocked_et_tostring: - element = to_xml_element(ome) - # Use an ElementTree.tostring kwarg and assert that it was passed through - to_xml(element, xml_declaration=True) - assert mocked_et_tostring.call_args.xml_declaration +# with mock.patch("xml.etree.ElementTree.tostring") as mocked_et_tostring: +# element = to_xml_element(ome) +# # Use an ElementTree.tostring kwarg and assert that it was passed through +# to_xml(element, xml_declaration=True) +# assert mocked_et_tostring.call_args.xml_declaration -@pytest.mark.parametrize("xml", xml_read, ids=true_stem) -@pytest.mark.parametrize("parser", parser) -@pytest.mark.parametrize("validate", validate) -def test_serialization(xml, validate, parser): - """Test pickle serialization and reserialization.""" - if true_stem(xml) in SHOULD_RAISE_READ: - pytest.skip("Can't pickle unreadable xml") - if validate and true_stem(xml) in SHOULD_FAIL_VALIDATION: - pytest.skip("Can't pickle invalid xml with validate=True") - - ome = 
from_xml(xml, parser=parser, validate=validate) - serialized = pickle.dumps(ome) - deserialized = pickle.loads(serialized) - assert ome == deserialized - - -def test_no_id(): +def test_no_id() -> None: """Test that ids are optional, and auto-increment.""" - i = model.Instrument(id=20) + i = model.Instrument(id=20) # type: ignore assert i.id == "Instrument:20" - i2 = model.Instrument() + i2 = model.Instrument() # type: ignore assert i2.id == "Instrument:21" # but validation still works - with pytest.raises(ValueError): + with pytest.warns(match="Casting invalid InstrumentID"): model.Instrument(id="nonsense") -def test_required_missing(): +def test_required_missing() -> None: """Test subclasses with non-default arguments still work.""" - with pytest.raises(ValidationError) as e: - _ = model.BooleanAnnotation() - assert "1 validation error for BooleanAnnotation" in str(e.value) - assert "value\n field required" in str(e.value) + with pytest.raises(ValidationError, match="value\n field required"): + model.BooleanAnnotation() # type: ignore - with pytest.raises(ValidationError) as e: - _ = model.Label() - assert "2 validation errors for Label" in str(e.value) - assert "x\n field required" in str(e.value) - assert "y\n field required" in str(e.value) + with pytest.raises(ValidationError, match="x\n field required"): + model.Label() # type: ignore -@pytest.mark.parametrize("parser", parser) -@pytest.mark.parametrize("validate", validate) -def test_refs(validate, parser): - xml = Path(__file__).parent / "data" / "two-screens-two-plates-four-wells.ome.xml" - ome = from_xml(xml, parser=parser, validate=validate) - assert ome.screens[0].plate_ref[0].ref is ome.plates[0] +def test_refs() -> None: + xml = DATA / "two-screens-two-plates-four-wells.ome.xml" + ome = from_xml(xml) + assert ome.screens[0].plate_refs[0].ref is ome.plates[0] -@pytest.mark.parametrize("validate", validate) -@pytest.mark.parametrize("parser", parser) -def test_with_ome_ns(validate, parser): - xml = 
Path(__file__).parent / "data" / "ome_ns.ome.xml" - ome = from_xml(xml, parser=parser, validate=validate) - assert ome.experimenters +def test_with_ome_ns() -> None: + assert from_xml(DATA / "ome_ns.ome.xml").experimenters + + +def test_get_ome_type() -> None: + t = _get_ome_type(f'') + assert t is model.Image + + with pytest.raises(ValueError): + _get_ome_type("") + + # this can be used to instantiate XML with a non OME root type: + project = from_xml(f'') + assert isinstance(project, model.Project) + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +def test_roundtrip(valid_xml: Path) -> None: + """Ensure we can losslessly round-trip XML through the model and back.""" + if true_stem(valid_xml) in SKIP_ROUNDTRIP: + pytest.xfail("known issues with canonicalization") + + original = _canonicalize(valid_xml.read_bytes()) + + ome = from_xml(valid_xml) + rexml = to_xml(ome) + new = _canonicalize(rexml) + if new != original: + Path("original.xml").write_text(original) + Path("rewritten.xml").write_text(new) + raise AssertionError + + +# ########## Canonicalization utils for testing ########## + + +def _canonicalize(xml: str | bytes, pretty: bool = False) -> str: + ET.register_namespace("ome", URI_OME) + + # The only reason we're using xmlschema at this point is because + # it converts floats properly CutIn="550" -> CutIn="550.0" based on the schema + # once that is fixed, we can remove xmlschema entirely + schema = _get_schema() + decoded = schema.decode(xml) + root = cast(ET.Element, schema.encode(decoded, path=f"{NS_OME}OME")) + + # Strip extra whitespace in the schemaLocation value. + root.attrib[SCHEMA_LOCATION] = re.sub(r"\s+", " ", root.attrib[SCHEMA_LOCATION]) + + # sorting elements actually breaks the validity of some documents, + # but it's useful for comparison sake. 
+ _sort_elements(root) + xml_out = ET.tostring(root, "unicode") + xml_out = ET.canonicalize(xml_out, strip_text=True) + if pretty: + # totally optional for comparison sake... but nice for debugging + xml_out = minidom.parseString(xml_out).toprettyxml(indent=" ") + return xml_out + + +@lru_cache(maxsize=None) +def _get_schema() -> xmlschema.XMLSchemaBase: + xmlschema = pytest.importorskip("xmlschema") + + schema = xmlschema.XMLSchema(OME_2016_06_XSD) + # FIXME Hack to work around xmlschema poor support for keyrefs to + # substitution groups. This can be removed, if decode(validation='skip') is used. + ls_sgs = schema.maps.substitution_groups[f"{NS_OME}LightSourceGroup"] + ls_id_maps = schema.maps.identities[f"{NS_OME}LightSourceIDKey"] + ls_id_maps.elements = {e: None for e in ls_sgs} + return schema + + +def _sort_elements(element: ET.Element, recursive: bool = True) -> None: + # Replace the existing child elements with the sorted ones + element[:] = sorted(element, key=lambda child: child.tag) + + if recursive: + # Recursively sort child elements for each subelement + for child in element: + _sort_elements(child) + + +def test_datetimes() -> None: + now = datetime.datetime.now() + anno = model.TimestampAnnotation(value=now) + assert anno.value == now + anno = model.TimestampAnnotation(value="0066-07-18T00:00:00") + assert anno.value == datetime.datetime(66, 7, 18) + + XML = """ + + -231400000-01-01T00:00:00 + + """ + with pytest.warns(match="Invalid datetime.*BC dates are not supported"): + from_xml(XML) + + +@pytest.mark.parametrize("only", [True, False, {}, None]) +def test_metadata_only(only: bool) -> None: + pix = model.Pixels( + metadata_only=only, # passing bool should be fine + size_c=1, + size_t=1, + size_x=1, + size_y=1, + size_z=1, + dimension_order="XYZCT", + type="uint8", + ) + if only not in (False, None): # note that empty dict is "truthy" for metadata_only + assert pix.metadata_only + else: + assert not pix.metadata_only + + +def test_deepcopy() -> 
None: + from copy import deepcopy + + ome = from_xml(DATA / "example.ome.xml") + newome = deepcopy(ome) + + assert ome == newome + assert ome is not newome + + +def test_structured_annotations() -> None: + long = model.LongAnnotation(value=1) + annotations = [model.CommentAnnotation(value="test comment"), long] + ome = model.OME(structured_annotations=annotations) + assert ome + assert len(ome.structured_annotations) == 2 + assert "Long" in ome.to_xml() + ome.structured_annotations.remove(long) + assert "Long" not in ome.to_xml() + + assert list(ome.structured_annotations) == ome.structured_annotations + + +def test_colors() -> None: + from ome_types.model.simple_types import Color + + shape = model.Shape(fill_color="red", stroke_color="blue") + assert isinstance(shape.fill_color, Color) + assert isinstance(shape.stroke_color, Color) + assert shape.fill_color.as_rgb_tuple() == (255, 0, 0) + assert shape.stroke_color.as_named() == "blue" + + assert model.Shape().fill_color is None + assert model.Shape().stroke_color is None diff --git a/tests/test_names.py b/tests/test_names.py new file mode 100644 index 00000000..a61baf26 --- /dev/null +++ b/tests/test_names.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Sequence + +import pytest +from pydantic import BaseModel +from pydantic.typing import display_as_type + +import ome_types + +TESTS = Path(__file__).parent +KNOWN_CHANGES: dict[str, list[tuple[str, str | None]]] = { + "OME.datasets": [ + ("annotation_ref", "annotation_refs"), + ("image_ref", "image_refs"), + ], + "OME.experimenter_groups": [ + ("annotation_ref", "annotation_refs"), + ("experimenter_ref", "experimenter_refs"), + ("leader", "leaders"), + ], + "OME.experimenters": [("annotation_ref", "annotation_refs")], + "OME.experiments.microbeam_manipulations": [ + ("roi_ref", "roi_refs"), + ("light_source_settings", "light_source_settings_combinations"), + ], + "OME.folders": [ + 
("annotation_ref", "annotation_refs"), + ("folder_ref", "folder_refs"), + ("image_ref", "image_refs"), + ("roi_ref", "roi_refs"), + ], + "OME.images.pixels": [("bin_data", "bin_data_blocks")], + "OME.images.pixels.channels": [("annotation_ref", "annotation_refs")], + "OME.images.pixels.channels.light_path": [ + ("annotation_ref", "annotation_refs"), + ("emission_filter_ref", "emission_filters"), + ("excitation_filter_ref", "excitation_filters"), + ], + "OME.images.pixels.planes": [("annotation_ref", "annotation_refs")], + "OME.images": [ + ("annotation_ref", "annotation_refs"), + ("microbeam_manipulation_ref", "microbeam_manipulation_refs"), + ("roi_ref", "roi_refs"), + ], + "OME.images.imaging_environment.map": [("m", "ms")], + "OME.instruments": [ + ("annotation_ref", "annotation_refs"), + ("light_source_group", None), + ], + "OME.instruments.detectors": [("annotation_ref", "annotation_refs")], + "OME.instruments.dichroics": [("annotation_ref", "annotation_refs")], + "OME.instruments.filter_sets": [ + ("emission_filter_ref", "emission_filters"), + ("excitation_filter_ref", "excitation_filters"), + ], + "OME.instruments.filters": [("annotation_ref", "annotation_refs")], + "OME.instruments.objectives": [("annotation_ref", "annotation_refs")], + "OME.plates": [("annotation_ref", "annotation_refs")], + "OME.plates.plate_acquisitions": [ + ("annotation_ref", "annotation_refs"), + ("well_sample_ref", "well_sample_refs"), + ], + "OME.plates.wells": [("annotation_ref", "annotation_refs")], + "OME.projects": [ + ("annotation_ref", "annotation_refs"), + ("dataset_ref", "dataset_refs"), + ], + "OME.rois": [("annotation_ref", "annotation_refs")], + "OME.screens": [("annotation_ref", "annotation_refs"), ("plate_ref", "plate_refs")], + "OME.screens.reagents": [("annotation_ref", "annotation_refs")], + # OME.structured_annotations went from + # List[Annotation] -> Optional[StructuredAnnotations] + # this is the main breaking change. 
+ "OME.structured_annotations": [ + ("annotation_ref", None), + ("id", None), + ("annotator", None), + ("description", None), + ("namespace", None), + ], +} + + +def _assert_names_match( + old: dict[str, Any], new: dict[str, Any], path: Sequence[str] = () +) -> None: + """Make sure every key in old is in new, or that it's in KNOWN_CHANGES.""" + for old_key, value in old.items(): + new_key = old_key + if old_key not in new: + _path = ".".join(path) + if _path in KNOWN_CHANGES: + for from_, new_key in KNOWN_CHANGES[_path]: # type: ignore + if old_key == from_ and (new_key in new or new_key is None): + break + else: + raise AssertionError( + f"Key {old_key!r} not in new model at {_path}: {list(new)}" + ) + else: + raise AssertionError(f"{_path!r} not in KNOWN_CHANGES") + + if isinstance(value, dict) and new_key in new: + _assert_names_match(value, new[new_key], (*path, old_key)) + + +def _get_fields(cls: type[BaseModel]) -> dict[str, Any]: + fields = {} + for name, field in cls.__fields__.items(): + if name.startswith("_"): + continue + if isinstance(field.type_, type) and issubclass(field.type_, BaseModel): + fields[name] = _get_fields(field.type_) + else: + fields[name] = display_as_type(field.outer_type_) # type: ignore + return fields + + +def test_names() -> None: + with (TESTS / "data" / "old_model.json").open() as f: + old_names = json.load(f) + new_names = _get_fields(ome_types.model.OME) + _assert_names_match(old_names, new_names, ("OME",)) + + +V1_EXPORTS = [ + ("affine_transform", "AffineTransform"), + ("annotation", "Annotation"), + ("annotation_ref", "AnnotationRef"), + ("arc", "Arc"), + ("basic_annotation", "BasicAnnotation"), + ("bin_data", "BinData"), + ("binary_file", "BinaryFile"), + ("boolean_annotation", "BooleanAnnotation"), + ("channel", "Channel"), + ("channel_ref", "ChannelRef"), + ("comment_annotation", "CommentAnnotation"), + ("dataset", "Dataset"), + ("dataset_ref", "DatasetRef"), + ("detector", "Detector"), + ("detector_settings", 
"DetectorSettings"), + ("dichroic", "Dichroic"), + ("dichroic_ref", "DichroicRef"), + ("double_annotation", "DoubleAnnotation"), + ("ellipse", "Ellipse"), + ("experiment", "Experiment"), + ("experiment_ref", "ExperimentRef"), + ("experimenter", "Experimenter"), + ("experimenter_group", "ExperimenterGroup"), + ("experimenter_group_ref", "ExperimenterGroupRef"), + ("experimenter_ref", "ExperimenterRef"), + ("external", "External"), + ("filament", "Filament"), + ("file_annotation", "FileAnnotation"), + ("filter", "Filter"), + ("filter_ref", "FilterRef"), + ("filter_set", "FilterSet"), + ("filter_set_ref", "FilterSetRef"), + ("folder", "Folder"), + ("folder_ref", "FolderRef"), + ("generic_excitation_source", "GenericExcitationSource"), + ("image", "Image"), + ("image_ref", "ImageRef"), + ("imaging_environment", "ImagingEnvironment"), + ("instrument", "Instrument"), + ("instrument_ref", "InstrumentRef"), + ("label", "Label"), + ("laser", "Laser"), + ("leader", "Leader"), + ("light_emitting_diode", "LightEmittingDiode"), + ("light_path", "LightPath"), + ("light_source", "LightSource"), + # ("light_source_group", "LightSourceGroup"), + ("light_source_settings", "LightSourceSettings"), + ("line", "Line"), + ("list_annotation", "ListAnnotation"), + ("long_annotation", "LongAnnotation"), + ("manufacturer_spec", "ManufacturerSpec"), + ("map", "Map"), + ("map_annotation", "MapAnnotation"), + ("mask", "Mask"), + ("microbeam_manipulation", "MicrobeamManipulation"), + ("microbeam_manipulation_ref", "MicrobeamManipulationRef"), + ("microscope", "Microscope"), + ("numeric_annotation", "NumericAnnotation"), + ("objective", "Objective"), + ("objective_settings", "ObjectiveSettings"), + ("ome", "OME"), + ("pixels", "Pixels"), + ("plane", "Plane"), + ("plate", "Plate"), + ("plate_acquisition", "PlateAcquisition"), + ("point", "Point"), + ("polygon", "Polygon"), + ("polyline", "Polyline"), + ("project", "Project"), + ("project_ref", "ProjectRef"), + ("pump", "Pump"), + ("reagent", 
"Reagent"), + ("reagent_ref", "ReagentRef"), + ("rectangle", "Rectangle"), + ("reference", "Reference"), + ("rights", "Rights"), + ("roi", "ROI"), + ("roi_ref", "ROIRef"), + ("screen", "Screen"), + ("settings", "Settings"), + ("shape", "Shape"), + # ("shape_group", "ShapeGroup"), + ("stage_label", "StageLabel"), + ("structured_annotations", "StructuredAnnotations"), + ("tag_annotation", "TagAnnotation"), + ("term_annotation", "TermAnnotation"), + ("text_annotation", "TextAnnotation"), + ("tiff_data", "TiffData"), + ("timestamp_annotation", "TimestampAnnotation"), + ("transmittance_range", "TransmittanceRange"), + ("type_annotation", "TypeAnnotation"), + ("well", "Well"), + ("well_sample", "WellSample"), + ("well_sample_ref", "WellSampleRef"), + ("xml_annotation", "XMLAnnotation"), +] + + +@pytest.mark.parametrize("name,cls_name", V1_EXPORTS) +def test_model_imports(name: str, cls_name: str) -> None: + from importlib import import_module + + # with pytest.warns(UserWarning, match="Importing submodules from ome_types.model"): + mod = import_module(f"ome_types.model.{name}") + + cls = getattr(mod, cls_name) + assert cls is not None + real_module = mod = import_module(f"ome_types._autogenerated.ome_2016_06.{name}") + + # modules and object must have the same id! 
This is important for pickle + assert real_module is mod + assert getattr(real_module, cls_name) is cls + + +def test_deprecated_attrs() -> None: + ome = ome_types.from_xml(TESTS / "data" / "instrument-units-default.ome.xml") + with pytest.warns( + match="Attribute 'FilterSet.excitation_filter_ref' is " + "deprecated, use 'excitation_filters'" + ): + ref1 = ome.instruments[0].filter_sets[0].excitation_filter_ref + assert ref1 is ome.instruments[0].filter_sets[0].excitation_filters diff --git a/tests/test_omero_cli.py b/tests/test_omero_cli.py new file mode 100644 index 00000000..98d7bfa0 --- /dev/null +++ b/tests/test_omero_cli.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import os +import sys +from typing import TYPE_CHECKING +from unittest.mock import MagicMock + +import pytest + +from ome_types import OME, model + +if TYPE_CHECKING: + from pathlib import Path + + from omero.gateway import BlitzGateway + from pytest import MonkeyPatch, TempPathFactory + + +# this test can be run with only `pip install omero-cli-transfer --no-deps` +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_populate_omero(monkeypatch: MonkeyPatch) -> None: + monkeypatch.setitem(sys.modules, "omero.gateway", MagicMock()) + monkeypatch.setitem(sys.modules, "omero.rtypes", MagicMock()) + monkeypatch.setitem(sys.modules, "omero.model", MagicMock()) + monkeypatch.setitem(sys.modules, "omero.sys", MagicMock()) + monkeypatch.setitem(sys.modules, "ezomero", MagicMock()) + + gen_omero = pytest.importorskip("generate_omero_objects") + + conn = MagicMock() + getId = conn.getUpdateService.return_value.saveAndReturnObject.return_value.getId + getId.return_value.val = 2 + + ann = model.CommentAnnotation(value="test comment", id="Annotation:-123") + plate = model.Plate( + name="MyPlate", annotation_refs=[model.AnnotationRef(id=ann.id)] + ) + img = model.Image( + name="MyImage", + pixels=model.Pixels( + size_c=1, + size_t=1, + size_z=10, + size_x=100, + size_y=100, 
dimension_order="XYZCT", + type="uint8", + ), + ) + project = model.Project(name="MyProject", description="Project description") + dataset = model.Dataset(name="MyDataset", image_refs=[model.ImageRef(id=img.id)]) + ome = OME( + images=[img], + plates=[plate], + structured_annotations=[ann], + projects=[project], + datasets=[dataset], + ) + gen_omero.populate_omero( + ome, img_map={}, conn=conn, hash="somehash", folder="", metadata=[] + ) + assert conn.method_calls + + +@pytest.fixture(scope="session") +def data_dir(tmp_path_factory: TempPathFactory) -> Path: + return tmp_path_factory.mktemp("data") + + +@pytest.fixture(scope="session") +def conn() -> BlitzGateway: + pytest.importorskip("omero.gateway") + + from omero.gateway import BlitzGateway + + user = os.environ["OMERO_USER"] + passwd = os.environ["OMERO_PASSWORD"] + host = os.environ["OMERO_HOST"] + conn = BlitzGateway(user, passwd, host=host, port=4064) + conn.connect() + try: + assert conn.isConnected() + yield conn + finally: + conn.close() + + +# To run this test, you must have omero-py installed and have the following +# environment variables set: +# OMERO_USER +# OMERO_PASSWORD +# OMERO_HOST +@pytest.mark.skipif("OMERO_USER" not in os.environ, reason="OMERO_USER not set") +@pytest.mark.parametrize( + "datatype, id", + [("Dataset", 21157), ("Project", 5414), ("Image", 1110952)], +) +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_populate_xml( + data_dir: Path, + datatype: str, + id: int, + conn: BlitzGateway, +) -> None: + from omero_cli_transfer import populate_xml + + dest = data_dir / "new.ome.xml" + ome, _ = populate_xml( + datatype=datatype, + id=id, + filepath=str(dest), + conn=conn, + hostname="host", + barchive=False, # write the file + metadata=[], + ) + assert isinstance(ome, OME) + assert dest.exists() + assert isinstance(OME.from_xml(str(dest)), OME) diff --git a/tests/test_paquo.py b/tests/test_paquo.py new file mode 100644 index 00000000..ca00a991 --- /dev/null +++ 
b/tests/test_paquo.py @@ -0,0 +1,33 @@ +import os +from pathlib import Path + +import pytest + +from ome_types import validate_xml + +# to run this test locally, you can download QuPath.app as follows: +# python -m paquo get_qupath --install-path ./qupath/apps --download-path ./qupath/download 0.4.3 # noqa: E501 +# (this is done automatically on CI) +if "PAQUO_QUPATH_DIR" not in os.environ: + qupath_apps = Path(__file__).parent.parent / "qupath" / "apps" + app_path = next(qupath_apps.glob("QuPath-*.app"), None) + if app_path is not None: + os.environ["PAQUO_QUPATH_DIR"] = str(app_path) + +try: + import shapely.geometry + from paquo.hierarchy import QuPathPathObjectHierarchy +except (ValueError, ImportError): + pytest.skip("Paquo not installed", allow_module_level=True) + + +@pytest.mark.filterwarnings("ignore:Importing submodules from ome_types.model") +def test_to_ome_xml() -> None: + h = QuPathPathObjectHierarchy() + h.add_annotation(roi=shapely.geometry.Point(1, 2)) + h.add_annotation(roi=shapely.geometry.LineString([(0, 0), (1, 1)])) + h.add_annotation(roi=shapely.geometry.LinearRing([(0, 0), (1, 1), (2, 2)])) + h.add_annotation(roi=shapely.geometry.box(0, 0, 1, 1)) + h.add_annotation(roi=shapely.geometry.Polygon([(0, 0), (1, 0), (2, 1), (0, 5)])) + xml = h.to_ome_xml() + validate_xml(xml) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 089348fb..fcc56dc3 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,5 +1,9 @@ +import pickle +from pathlib import Path + import pytest +from ome_types import from_xml from ome_types.model import OME, Channel, Image, Pixels @@ -16,7 +20,6 @@ def test_color_unset(channel_kwargs: dict) -> None: size_z=1, dimension_order="XYZTC", type="uint16", - metadata_only=True, channels=[Channel(**channel_kwargs)], ) ) @@ -24,3 +27,11 @@ def test_color_unset(channel_kwargs: dict) -> None: ) assert ("Color" in ome.to_xml()) is bool(channel_kwargs) + + +def test_serialization(valid_xml: 
Path) -> None: + """Test pickle serialization and reserialization.""" + ome = from_xml(valid_xml) + serialized = pickle.dumps(ome) + deserialized = pickle.loads(serialized) + assert ome == deserialized diff --git a/tests/test_units.py b/tests/test_units.py index cb4fff6d..b8b88c35 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -2,8 +2,8 @@ from pint import DimensionalityError from pydantic import ValidationError -from ome_types._units import ureg from ome_types.model import Channel, Laser, Plane, simple_types +from ome_types.units import ureg def test_quantity_math(): @@ -64,11 +64,9 @@ def test_reference_frame(): assert position_x.check("[length]") -def test_all_units(): +def test_all_units() -> None: """Test that all Unit* enums are in the registry.""" - for t in dir(simple_types): - if not t.startswith("Unit"): + for name, obj in vars(simple_types).items(): + if not name.startswith("Unit"): continue - e = getattr(simple_types, t) - for v in e.__members__.values(): - assert v.value.replace(" ", "_") in ureg + assert all(m.value.replace(" ", "_") in ureg for m in obj) diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 00000000..06dc8d59 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import contextlib +from typing import TYPE_CHECKING, Callable + +import pytest + +from ome_types import validate_xml, validation + +if TYPE_CHECKING: + from pathlib import Path + +VALIDATORS: dict[str, Callable] = {} +with contextlib.suppress(ImportError): + import lxml # noqa: F401 + + VALIDATORS["lxml"] = validation.validate_xml_with_lxml + +with contextlib.suppress(ImportError): + import xmlschema # noqa: F401 + + VALIDATORS["xmlschema"] = validation.validate_xml_with_xmlschema + + +@pytest.mark.parametrize("backend", VALIDATORS) +def test_validation_good(valid_xml: Path, backend: str) -> None: + VALIDATORS[backend](valid_xml) + + +def test_validation_anybackend(single_xml: 
Path) -> None: + if VALIDATORS: + validate_xml(single_xml) + else: + with pytest.raises(ImportError): + validate_xml(single_xml) + + +@pytest.mark.parametrize("backend", VALIDATORS) +def test_validation_raises(invalid_xml: Path, backend: str) -> None: + with pytest.raises(validation.ValidationError): + VALIDATORS[backend](invalid_xml) diff --git a/tests/test_widget.py b/tests/test_widget.py index d00298e8..843b8b27 100644 --- a/tests/test_widget.py +++ b/tests/test_widget.py @@ -4,7 +4,8 @@ nplg = pytest.importorskip("ome_types._napari_plugin") -DATA = Path(__file__).parent / "data" +TESTS = Path(__file__).parent +DATA = TESTS / "data" @pytest.mark.parametrize("fname", DATA.iterdir(), ids=lambda x: x.stem) diff --git a/tests/util.py b/tests/util.py deleted file mode 100644 index c9e964fc..00000000 --- a/tests/util.py +++ /dev/null @@ -1,350 +0,0 @@ -import io -import re -from xml.etree import ElementTree - -# -------------------------------------------------------------------- -# Taken from Python 3.8's ElementTree.py. - - -def canonicalize(xml_data=None, *, out=None, from_file=None, **options): - """Convert XML to its C14N 2.0 serialised form. - - If *out* is provided, it must be a file or file-like object that receives - the serialised canonical XML output (text, not bytes) through its ``.write()`` - method. To write to a file, open it in text mode with encoding "utf-8". - If *out* is not provided, this function returns the output as text string. - - Either *xml_data* (an XML string) or *from_file* (a file path or - file-like object) must be provided as input. - - The configuration options are the same as for the ``C14NWriterTarget``. 
- """ - if xml_data is None and from_file is None: - raise ValueError("Either 'xml_data' or 'from_file' must be provided as input") - sio = None - if out is None: - sio = out = io.StringIO() - - parser = ElementTree.XMLParser(target=C14NWriterTarget(out.write, **options)) - - if xml_data is not None: - parser.feed(xml_data) - parser.close() - elif from_file is not None: - ElementTree.parse(from_file, parser=parser) - - return sio.getvalue() if sio is not None else None - - -_looks_like_prefix_name = re.compile(r"^\w+:\w+$", re.UNICODE).match - - -class C14NWriterTarget: - """ - Canonicalization writer target for the XMLParser. - - Serialises parse events to XML C14N 2.0. - - The *write* function is used for writing out the resulting data stream - as text (not bytes). To write to a file, open it in text mode with encoding - "utf-8" and pass its ``.write`` method. - - Configuration options: - - - *with_comments*: set to true to include comments - - *strip_text*: set to true to strip whitespace before and after text content - - *rewrite_prefixes*: set to true to replace namespace prefixes by "n{number}" - - *qname_aware_tags*: a set of qname aware tag names in which prefixes - should be replaced in text content - - *qname_aware_attrs*: a set of qname aware attribute names in which prefixes - should be replaced in text content - - *exclude_attrs*: a set of attribute names that should not be serialised - - *exclude_tags*: a set of tag names that should not be serialised - """ - - def __init__( - self, - write, - *, - with_comments=False, - strip_text=False, - rewrite_prefixes=False, - qname_aware_tags=None, - qname_aware_attrs=None, - exclude_attrs=None, - exclude_tags=None, - ): - self._write = write - self._data = [] - self._with_comments = with_comments - self._strip_text = strip_text - self._exclude_attrs = set(exclude_attrs) if exclude_attrs else None - self._exclude_tags = set(exclude_tags) if exclude_tags else None - - self._rewrite_prefixes = rewrite_prefixes - 
if qname_aware_tags: - self._qname_aware_tags = set(qname_aware_tags) - else: - self._qname_aware_tags = None - if qname_aware_attrs: - self._find_qname_aware_attrs = set(qname_aware_attrs).intersection - else: - self._find_qname_aware_attrs = None - - # Stack with globally and newly declared namespaces as (uri, prefix) pairs. - self._declared_ns_stack = [ - [ - ("http://www.w3.org/XML/1998/namespace", "xml"), - ] - ] - # Stack with user declared namespace prefixes as (uri, prefix) pairs. - self._ns_stack = [] - if not rewrite_prefixes: - # this must not be separated from ElementTree - self._ns_stack.append(list(ElementTree._namespace_map.items())) - self._ns_stack.append([]) - self._prefix_map = {} - self._preserve_space = [False] - self._pending_start = None - self._root_seen = False - self._root_done = False - self._ignored_depth = 0 - - def _iter_namespaces(self, ns_stack, _reversed=reversed): - for namespaces in _reversed(ns_stack): - if namespaces: # almost no element declares new namespaces - yield from namespaces - - def _resolve_prefix_name(self, prefixed_name): - prefix, name = prefixed_name.split(":", 1) - for uri, p in self._iter_namespaces(self._ns_stack): - if p == prefix: - return f"{{{uri}}}{name}" - raise ValueError( - f'Prefix {prefix} of QName "{prefixed_name}" is not declared in scope' - ) - - def _qname(self, qname, uri=None): - if uri is None: - uri, tag = qname[1:].rsplit("}", 1) if qname[:1] == "{" else ("", qname) - else: - tag = qname - - prefixes_seen = set() - for u, prefix in self._iter_namespaces(self._declared_ns_stack): - if u == uri and prefix not in prefixes_seen: - return f"{prefix}:{tag}" if prefix else tag, tag, uri - prefixes_seen.add(prefix) - - # Not declared yet => add new declaration. 
- if self._rewrite_prefixes: - if uri in self._prefix_map: - prefix = self._prefix_map[uri] - else: - prefix = self._prefix_map[uri] = f"n{len(self._prefix_map)}" - self._declared_ns_stack[-1].append((uri, prefix)) - return f"{prefix}:{tag}", tag, uri - - if not uri and "" not in prefixes_seen: - # No default namespace declared => no prefix needed. - return tag, tag, uri - - for u, prefix in self._iter_namespaces(self._ns_stack): - if u == uri: - self._declared_ns_stack[-1].append((uri, prefix)) - return f"{prefix}:{tag}" if prefix else tag, tag, uri - - raise ValueError(f'Namespace "{uri}" is not declared in scope') - - def data(self, data): - if not self._ignored_depth: - self._data.append(data) - - def _flush(self, _join_text="".join): - data = _join_text(self._data) - del self._data[:] - if self._strip_text and not self._preserve_space[-1]: - data = data.strip() - if self._pending_start is not None: - args, self._pending_start = self._pending_start, None - qname_text = data if data and _looks_like_prefix_name(data) else None - self._start(*args, qname_text) - if qname_text is not None: - return - if data and self._root_seen: - self._write(_escape_cdata_c14n(data)) - - def start_ns(self, prefix, uri): - if self._ignored_depth: - return - # we may have to resolve qnames in text content - if self._data: - self._flush() - self._ns_stack[-1].append((uri, prefix)) - - def start(self, tag, attrs): - if self._exclude_tags is not None and ( - self._ignored_depth or tag in self._exclude_tags - ): - self._ignored_depth += 1 - return - if self._data: - self._flush() - - new_namespaces = [] - self._declared_ns_stack.append(new_namespaces) - - if self._qname_aware_tags is not None and tag in self._qname_aware_tags: - # Need to parse text first to see if it requires a prefix declaration. 
- self._pending_start = (tag, attrs, new_namespaces) - return - self._start(tag, attrs, new_namespaces) - - def _start(self, tag, attrs, new_namespaces, qname_text=None): - if self._exclude_attrs is not None and attrs: - attrs = {k: v for k, v in attrs.items() if k not in self._exclude_attrs} - - qnames = {tag, *attrs} - resolved_names = {} - - # Resolve prefixes in attribute and tag text. - if qname_text is not None: - qname = resolved_names[qname_text] = self._resolve_prefix_name(qname_text) - qnames.add(qname) - if self._find_qname_aware_attrs is not None and attrs: - qattrs = self._find_qname_aware_attrs(attrs) - if qattrs: - for attr_name in qattrs: - value = attrs[attr_name] - if _looks_like_prefix_name(value): - qname = resolved_names[value] = self._resolve_prefix_name(value) - qnames.add(qname) - else: - qattrs = None - else: - qattrs = None - - # Assign prefixes in lexicographical order of used URIs. - parse_qname = self._qname - parsed_qnames = { - n: parse_qname(n) for n in sorted(qnames, key=lambda n: n.split("}", 1)) - } - - # Write namespace declarations in prefix order ... - if new_namespaces: - attr_list = [ - ("xmlns:" + prefix if prefix else "xmlns", uri) - for uri, prefix in new_namespaces - ] - attr_list.sort() - else: - # almost always empty - attr_list = [] - - # ... followed by attributes in URI+name order - if attrs: - for k, v in sorted(attrs.items()): - if qattrs is not None and k in qattrs and v in resolved_names: - v = parsed_qnames[resolved_names[v]][0] - attr_qname, attr_name, uri = parsed_qnames[k] - # No prefix for attributes in default ('') namespace. - attr_list.append((attr_qname if uri else attr_name, v)) - - # Honour xml:space attributes. - space_behaviour = attrs.get("{http://www.w3.org/XML/1998/namespace}space") - self._preserve_space.append( - space_behaviour == "preserve" - if space_behaviour - else self._preserve_space[-1] - ) - - # Write the tag. 
- write = self._write - write("<" + parsed_qnames[tag][0]) - if attr_list: - write("".join([f' {k}="{_escape_attrib_c14n(v)}"' for k, v in attr_list])) - write(">") - - # Write the resolved qname text content. - if qname_text is not None: - write(_escape_cdata_c14n(parsed_qnames[resolved_names[qname_text]][0])) - - self._root_seen = True - self._ns_stack.append([]) - - def end(self, tag): - if self._ignored_depth: - self._ignored_depth -= 1 - return - if self._data: - self._flush() - self._write(f"") - self._preserve_space.pop() - self._root_done = len(self._preserve_space) == 1 - self._declared_ns_stack.pop() - self._ns_stack.pop() - - def comment(self, text): - if not self._with_comments: - return - if self._ignored_depth: - return - if self._root_done: - self._write("\n") - elif self._root_seen and self._data: - self._flush() - self._write(f"") - if not self._root_seen: - self._write("\n") - - def pi(self, target, data): - if self._ignored_depth: - return - if self._root_done: - self._write("\n") - elif self._root_seen and self._data: - self._flush() - self._write( - f"" if data else f"" - ) - if not self._root_seen: - self._write("\n") - - -def _escape_cdata_c14n(text): - # escape character data - try: - # it's worth avoiding do-nothing calls for strings that are - # shorter than 500 character, or so. assume that's, by far, - # the most common case in most applications. 
- if "&" in text: - text = text.replace("&", "&") - if "<" in text: - text = text.replace("<", "<") - if ">" in text: - text = text.replace(">", ">") - if "\r" in text: - text = text.replace("\r", " ") - return text - except (TypeError, AttributeError): - ElementTree._raise_serialization_error(text) - - -def _escape_attrib_c14n(text): - # escape attribute value - try: - if "&" in text: - text = text.replace("&", "&") - if "<" in text: - text = text.replace("<", "<") - if '"' in text: - text = text.replace('"', """) - if "\t" in text: - text = text.replace("\t", " ") - if "\n" in text: - text = text.replace("\n", " ") - if "\r" in text: - text = text.replace("\r", " ") - return text - except (TypeError, AttributeError): - ElementTree._raise_serialization_error(text) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index b374386a..00000000 --- a/tox.ini +++ /dev/null @@ -1,49 +0,0 @@ -[tox] -envlist = py{37,38,39,310,311}-{linux,macos,windows},pre-commit -toxworkdir = /tmp/.tox -isolated_build = true - -[gh-actions] -python = - 3.7: py37 - 3.8: py38 - 3.9: py39 - 3.10: py310 - 3.11: py311 - -[gh-actions:env] -platform = - ubuntu-latest: linux - macos-latest: macos - windows-latest: windows - -[testenv] -platform = - macos: darwin - linux: linux - windows: win32 -passenv = - CI - GITHUB_ACTIONS - CODECOV_TOKEN -extras = test -commands = pytest -v --basetemp={envtmpdir} {posargs} - -[testenv:benchmark] -passenv = - CI - GITHUB_ACTIONS -extras = test -commands = pytest -v --benchmark-enable --basetemp={envtmpdir} {posargs} - -[testenv:mypy] -deps = mypy -commands = mypy - -[testenv:lint] -deps = flake8 -commands = flake8 - -[testenv:pre-commit] -deps = pre-commit -commands = pre-commit run --all-files diff --git a/typesafety/test_type_inits.yml b/typesafety/test_type_inits.yml new file mode 100644 index 00000000..2dcbae85 --- /dev/null +++ b/typesafety/test_type_inits.yml @@ -0,0 +1,25 @@ +- case: types_requiring_no_arguments + main: | + import ome_types.model as 
m + + m.OME() + m.Annotation() + m.BasicAnnotation() + m.Dataset() + m.Arc() + m.Microscope() + ch = m.Channel() + reveal_type(ch.id) # N: Revealed type is "builtins.str" + +- case: types_requiring_arguments + main: | + import ome_types.model as m + + m.BinData(value=b'213', length=1) # ER: Missing named argument "big_endian" .* + m.Image() # ER: Missing named argument "pixels" .* + +- case: extra_arguments + main: | + import ome_types.model as m + + m.Channel(idd='123') # ER: Unexpected keyword argument "idd" .*