Skip to content

Commit

Permalink
Use anndata.experimental methods to read remote anndata via dask
Browse files Browse the repository at this point in the history
  • Loading branch information
will-moore committed Mar 13, 2023
1 parent 17414e3 commit 3163412
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 39 deletions.
37 changes: 36 additions & 1 deletion ome_zarr/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,23 @@
import json
import logging
from pathlib import Path
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
from urllib.parse import urljoin

import dask.array as da
import zarr as zarrlib
from anndata import AnnData
from anndata._io.specs import IOSpec
from anndata.compat import H5Array, H5Group, ZarrArray, ZarrGroup

# ** requires anndata==0.9.0.rc1
from anndata.experimental import read_dispatched, read_elem
from zarr.storage import FSStore

from .format import CurrentFormat, Format, detect_format
from .types import JSONDict

StorageType = Union[H5Array, H5Group, ZarrArray, ZarrGroup]
LOGGER = logging.getLogger("ome_zarr.io")


Expand Down Expand Up @@ -194,3 +202,30 @@ def parse_url(
LOGGER.exception("exception on parsing (stacktrace at DEBUG)")
LOGGER.debug("stacktrace:", exc_info=True)
return None


# From https://anndata.readthedocs.io/en/latest/tutorials/
# notebooks/%7Bread,write%7D_dispatched.html


def read_remote_anndata(store: FSStore, name: str) -> AnnData:
table_group = zarrlib.group(store=store, path=name)

def callback(
func: Callable, elem_name: str, elem: StorageType, iospec: IOSpec
) -> AnnData:
if iospec.encoding_type in (
"dataframe",
"csr_matrix",
"csc_matrix",
"awkward-array",
):
# Preventing recursing inside of these types
return read_elem(elem)
elif iospec.encoding_type == "array":
return da.from_zarr(elem)
else:
return func(elem)

adata = read_dispatched(table_group, callback=callback)
return adata
41 changes: 4 additions & 37 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,17 @@

import logging
import math
import os
from abc import ABC
from typing import Any, Dict, Iterator, List, Optional, Type, Union, cast, overload

import dask.array as da
import numpy as np

# experimental failed to import
# from anndata.experimental import read_dispatched, write_dispatched, read_elem
from anndata import AnnData
from anndata._io import read_zarr
from dask import delayed

from .axes import Axes
from .format import format_from_version
from .io import ZarrLocation
from .io import ZarrLocation, read_remote_anndata
from .types import JSONDict

LOGGER = logging.getLogger("ome_zarr.reader")
Expand Down Expand Up @@ -187,32 +182,6 @@ def lookup(self, key: str, default: Any) -> Any:
return self.zarr.root_attrs.get(key, default)


# From https://anndata.readthedocs.io/en/latest/tutorials/
# notebooks/%7Bread,write%7D_dispatched.html

# def read_dask(store, path):
# f = zarr.open(store, path=path, mode="r")

# def callback(func, elem_name: str, elem, iospec):
# print("callback", iospec.encoding_type)
# if iospec.encoding_type in (
# "dataframe",
# "csr_matrix",
# "csc_matrix",
# "awkward-array",
# ):
# # Preventing recursing inside of these types
# return read_elem(elem)
# elif iospec.encoding_type == "array":
# return da.from_zarr(elem)
# else:
# return func(elem)

# adata = read_dispatched(f, callback=callback)

# return adata


class Tables(Spec):
"""Class to represent a "tables" group which only
contains the name of subgroups which should be loaded as labeled images."""
Expand All @@ -226,15 +195,13 @@ def __init__(self, node: Node) -> None:
super().__init__(node)
table_names = self.lookup("tables", [])
node.tables = {}
store = self.zarr.store
for name in table_names:
child_zarr = self.zarr.create(name)
if child_zarr.exists():
node.add(child_zarr)

# node.tables[name] = read_dask(store, name)
full_path = os.path.join(store.path, name)
node.tables[name] = read_zarr(full_path)
LOGGER.info("Reading anndata table: %s", name)
anndata_obj = read_remote_anndata(node.zarr.store, name)
node.tables[name] = anndata_obj


class Labels(Spec):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def read(fname):


install_requires: List[List[str]] = []
install_requires += (["anndata"],)
install_requires += (["anndata=0.9.0.rc1"],)
install_requires += (["dataclasses;python_version<'3.7'"],)
install_requires += (["tifffile<2020.09.22;python_version<'3.7'"],)
install_requires += (["numpy"],)
Expand Down

0 comments on commit 3163412

Please sign in to comment.