diff --git a/ome_zarr/io.py b/ome_zarr/io.py index c3ad9de5..67ab246e 100644 --- a/ome_zarr/io.py +++ b/ome_zarr/io.py @@ -6,15 +6,23 @@ import json import logging from pathlib import Path -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union from urllib.parse import urljoin import dask.array as da +import zarr as zarrlib +from anndata import AnnData +from anndata._io.specs import IOSpec +from anndata.compat import H5Array, H5Group, ZarrArray, ZarrGroup + +# ** requires anndata==0.9.0.rc1 +from anndata.experimental import read_dispatched, read_elem from zarr.storage import FSStore from .format import CurrentFormat, Format, detect_format from .types import JSONDict +StorageType = Union[H5Array, H5Group, ZarrArray, ZarrGroup] LOGGER = logging.getLogger("ome_zarr.io") @@ -194,3 +202,30 @@ def parse_url( LOGGER.exception("exception on parsing (stacktrace at DEBUG)") LOGGER.debug("stacktrace:", exc_info=True) return None + + +# From https://anndata.readthedocs.io/en/latest/tutorials/ +# notebooks/%7Bread,write%7D_dispatched.html + + +def read_remote_anndata(store: FSStore, name: str) -> AnnData: + table_group = zarrlib.group(store=store, path=name) + + def callback( + func: Callable, elem_name: str, elem: StorageType, iospec: IOSpec + ) -> AnnData: + if iospec.encoding_type in ( + "dataframe", + "csr_matrix", + "csc_matrix", + "awkward-array", + ): + # Preventing recursing inside of these types + return read_elem(elem) + elif iospec.encoding_type == "array": + return da.from_zarr(elem) + else: + return func(elem) + + adata = read_dispatched(table_group, callback=callback) + return adata diff --git a/ome_zarr/reader.py b/ome_zarr/reader.py index 3592a358..b13dd896 100644 --- a/ome_zarr/reader.py +++ b/ome_zarr/reader.py @@ -2,22 +2,17 @@ import logging import math -import os from abc import ABC from typing import Any, Dict, Iterator, List, Optional, Type, Union, cast, overload import dask.array as da import numpy as np - -# 
experimental failed to import -# from anndata.experimental import read_dispatched, write_dispatched, read_elem from anndata import AnnData -from anndata._io import read_zarr from dask import delayed from .axes import Axes from .format import format_from_version -from .io import ZarrLocation +from .io import ZarrLocation, read_remote_anndata from .types import JSONDict LOGGER = logging.getLogger("ome_zarr.reader") @@ -187,32 +182,6 @@ def lookup(self, key: str, default: Any) -> Any: return self.zarr.root_attrs.get(key, default) -# From https://anndata.readthedocs.io/en/latest/tutorials/ -# notebooks/%7Bread,write%7D_dispatched.html - -# def read_dask(store, path): -# f = zarr.open(store, path=path, mode="r") - -# def callback(func, elem_name: str, elem, iospec): -# print("callback", iospec.encoding_type) -# if iospec.encoding_type in ( -# "dataframe", -# "csr_matrix", -# "csc_matrix", -# "awkward-array", -# ): -# # Preventing recursing inside of these types -# return read_elem(elem) -# elif iospec.encoding_type == "array": -# return da.from_zarr(elem) -# else: -# return func(elem) - -# adata = read_dispatched(f, callback=callback) - -# return adata - - class Tables(Spec): """Class to represent a "tables" group which only contains the name of subgroups which should be loaded as labeled images.""" @@ -226,15 +195,13 @@ def __init__(self, node: Node) -> None: super().__init__(node) table_names = self.lookup("tables", []) node.tables = {} - store = self.zarr.store for name in table_names: child_zarr = self.zarr.create(name) if child_zarr.exists(): node.add(child_zarr) - - # node.tables[name] = read_dask(store, name) - full_path = os.path.join(store.path, name) - node.tables[name] = read_zarr(full_path) + LOGGER.info("Reading anndata table: %s", name) + anndata_obj = read_remote_anndata(node.zarr.store, name) + node.tables[name] = anndata_obj class Labels(Spec): diff --git a/setup.py b/setup.py index 940870ad..2deca46e 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 
@@ def read(fname): install_requires: List[List[str]] = [] -install_requires += (["anndata"],) +install_requires += (["anndata==0.9.0.rc1"],) install_requires += (["dataclasses;python_version<'3.7'"],) install_requires += (["tifffile<2020.09.22;python_version<'3.7'"],) install_requires += (["numpy"],)