From fcd6286e0ddeddd73e9bf8123cba1972a307d531 Mon Sep 17 00:00:00 2001
From: Tom White
Date: Fri, 5 Apr 2024 11:35:38 +0100
Subject: [PATCH] Implement virtual array indexing using ndindex

---
 cubed/core/ops.py                 | 17 +++---
 cubed/runtime/executors/modal.py  |  2 +
 cubed/storage/virtual.py          | 95 ++++++++-----------------------
 cubed/storage/zarr.py             | 24 +++++---
 cubed/tests/runtime/test_modal.py |  1 +
 cubed/tests/test_indexing.py      | 10 ++--
 pyproject.toml                    |  1 +
 requirements.txt                  |  1 +
 setup.cfg                         |  2 +
 9 files changed, 62 insertions(+), 91 deletions(-)

diff --git a/cubed/core/ops.py b/cubed/core/ops.py
index 848587a7..90df1f72 100644
--- a/cubed/core/ops.py
+++ b/cubed/core/ops.py
@@ -17,7 +17,6 @@
     OrthogonalIndexer,
     SliceDimIndexer,
     is_integer_list,
-    is_slice,
     replace_ellipsis,
 )

@@ -409,7 +408,7 @@ def index(x, key):
         key = (key,)

     # No op case
-    if all(is_slice(ind) and ind == slice(None) for ind in key):
+    if all(isinstance(ind, slice) and ind == slice(None) for ind in key):
         return x

     # Remove None values, to be filled in with expand_dims at end
@@ -436,7 +435,9 @@ def index(x, key):
     selection = replace_ellipsis(selection, x.shape)

     # Check selection is supported
-    if any(s.step is not None and s.step < 1 for s in selection if is_slice(s)):
+    if any(
+        s.step is not None and s.step < 1 for s in selection if isinstance(s, slice)
+    ):
         raise NotImplementedError(f"Slice step must be >= 1: {key}")
     assert all(isinstance(s, (slice, list, Integral)) for s in selection)
     where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
@@ -490,7 +491,6 @@ def merged_chunk_len_for_indexer(s):
         extra_projected_mem=extra_projected_mem,
         target_chunks=target_chunks,
         selection=selection,
-        advanced_indexing=len(where_list) > 0,
     )

     # merge chunks for any dims with step > 1 so they are
@@ -516,13 +516,14 @@ def _read_index_chunk(
     *arrays,
     target_chunks=None,
     selection=None,
-    advanced_indexing=None,
     block_id=None,
 ):
     array = arrays[0].zarray
-    if advanced_indexing:
-        array = array.oindex
     idx = block_id
+    # Note that since we only have a maximum of one integer array index
+    # we don't need to use Zarr orthogonal indexing, since it is
+    # "available directly on the array" according to
+    # https://zarr.readthedocs.io/en/stable/tutorial.html#orthogonal-indexing
     out = array[_target_chunk_selection(target_chunks, idx, selection)]
     out = numpy_array_to_backend_array(out)
     return out
@@ -535,7 +536,7 @@ def _target_chunk_selection(target_chunks, idx, selection):
     sel = []
     i = 0  # index into target_chunks and idx
     for s in selection:
-        if is_slice(s):
+        if isinstance(s, slice):
             offset = s.start or 0
             step = s.step if s.step is not None else 1
             start = tuple(
diff --git a/cubed/runtime/executors/modal.py b/cubed/runtime/executors/modal.py
index c5a767d3..5bdb101c 100644
--- a/cubed/runtime/executors/modal.py
+++ b/cubed/runtime/executors/modal.py
@@ -38,6 +38,7 @@
     "donfig",
     "fsspec",
     "mypy_extensions",  # for rechunker
+    "ndindex",
     "networkx",
     "pytest-mock",  # TODO: only needed for tests
     "s3fs",
@@ -52,6 +53,7 @@
     "donfig",
     "fsspec",
     "mypy_extensions",  # for rechunker
+    "ndindex",
     "networkx",
     "pytest-mock",  # TODO: only needed for tests
     "gcsfs",
diff --git a/cubed/storage/virtual.py b/cubed/storage/virtual.py
index cc5f2ad2..2c98d3ad 100644
--- a/cubed/storage/virtual.py
+++ b/cubed/storage/virtual.py
@@ -2,10 +2,8 @@
 from typing import Any

 import numpy as np
-import zarr
-from zarr.indexing import BasicIndexer, is_slice
+from ndindex import ndindex

-from cubed.backend_array_api import backend_array_to_numpy_array
 from cubed.backend_array_api import namespace as nxp
 from cubed.backend_array_api import numpy_array_to_backend_array
 from cubed.types import T_DType, T_RegularChunks, T_Shape
@@ -21,31 +19,21 @@ def __init__(
         dtype: T_DType,
         chunks: T_RegularChunks,
     ):
-        # use an empty in-memory Zarr array as a template since it normalizes its properties
-        template = zarr.empty(
-            shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
-        )
-        self.shape = template.shape
-        self.dtype = template.dtype
-        self.chunks = template.chunks
-        self.template = template
+        self.shape = shape
+        self.dtype = np.dtype(dtype)
+        self.chunks = chunks

     def __getitem__(self, key):
-        if not isinstance(key, tuple):
-            key = (key,)
-        indexer = BasicIndexer(key, self.template)
+        idx = ndindex[key]
+        newshape = idx.newshape(self.shape)
         # use broadcast trick so array chunks only occupy a single value in memory
-        return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)
+        return broadcast_trick(nxp.empty)(newshape, dtype=self.dtype)

     @property
     def chunkmem(self):
         # take broadcast trick into account
         return array_memory(self.dtype, (1,))

-    @property
-    def oindex(self):
-        return self.template.oindex
-

 class VirtualFullArray:
     """An array that is never materialized (in memory or on disk) and contains a single fill value."""
@@ -57,38 +45,19 @@ def __init__(
         chunks: T_RegularChunks,
         fill_value: Any = None,
     ):
-        # use an empty in-memory Zarr array as a template since it normalizes its properties
-        template = zarr.full(
-            shape,
-            fill_value,
-            dtype=dtype,
-            chunks=chunks,
-            store=zarr.storage.MemoryStore(),
-        )
-        self.shape = template.shape
-        self.dtype = template.dtype
-        self.chunks = template.chunks
-        self.template = template
+        self.shape = shape
+        self.dtype = np.dtype(dtype)
+        self.chunks = chunks
         self.fill_value = fill_value

     def __getitem__(self, key):
-        if not isinstance(key, tuple):
-            key = (key,)
-        indexer = BasicIndexer(key, self.template)
+        idx = ndindex[key]
+        newshape = idx.newshape(self.shape)
         # use broadcast trick so array chunks only occupy a single value in memory
         return broadcast_trick(nxp.full)(
-            indexer.shape, fill_value=self.fill_value, dtype=self.dtype
+            newshape, fill_value=self.fill_value, dtype=self.dtype
         )

-    @property
-    def chunkmem(self):
-        # take broadcast trick into account
-        return array_memory(self.dtype, (1,))
-
-    @property
-    def oindex(self):
-        return self.template.oindex
-

 class VirtualOffsetsArray:
     """An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers."""
@@ -96,14 +65,9 @@
     def __init__(self, shape: T_Shape):
         dtype = nxp.int32
         chunks = (1,) * len(shape)
-        # use an empty in-memory Zarr array as a template since it normalizes its properties
-        template = zarr.empty(
-            shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
-        )
-        self.shape = template.shape
-        self.dtype = template.dtype
-        self.chunks = template.chunks
-        self.ndim = template.ndim
+        self.shape = shape
+        self.dtype = np.dtype(dtype)
+        self.chunks = chunks

     def __getitem__(self, key):
         if key == () and self.shape == ():
@@ -127,28 +91,13 @@ def __init__(
                 f"Size of in memory array is {memory_repr(array.nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`."
             )
         self.array = array
-        # use an in-memory Zarr array as a template since it normalizes its properties
-        # and is needed for oindex
-        template = zarr.empty(
-            array.shape,
-            dtype=array.dtype,
-            chunks=chunks,
-            store=zarr.storage.MemoryStore(),
-        )
-        self.shape = template.shape
-        self.dtype = template.dtype
-        self.chunks = template.chunks
-        self.template = template
-        if array.size > 0:
-            template[...] = backend_array_to_numpy_array(array)
+        self.shape = array.shape
+        self.dtype = array.dtype
+        self.chunks = chunks

     def __getitem__(self, key):
         return self.array.__getitem__(key)

-    @property
-    def oindex(self):
-        return self.template.oindex
-

 def _key_to_index_tuple(selection):
     if isinstance(selection, slice):
@@ -158,7 +107,11 @@
     for s in selection:
         if isinstance(s, Integral):
             sel.append(s)
-        elif is_slice(s) and s.stop == s.start + 1 and (s.step is None or s.step == 1):
+        elif (
+            isinstance(s, slice)
+            and s.stop == s.start + 1
+            and (s.step is None or s.step == 1)
+        ):
             sel.append(s.start)
         else:
             raise NotImplementedError(f"Offset selection not supported: {selection}")
diff --git a/cubed/storage/zarr.py b/cubed/storage/zarr.py
index af048d78..136740b2 100644
--- a/cubed/storage/zarr.py
+++ b/cubed/storage/zarr.py
@@ -1,6 +1,9 @@
+from operator import mul
 from typing import Optional, Union

+import numpy as np
 import zarr
+from toolz import reduce

 from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store

@@ -23,18 +26,23 @@ def __init__(
         **kwargs,
     ):
         """Create a Zarr array lazily in memory."""
-        # use an empty in-memory Zarr array as a template since it normalizes its properties
-        template = zarr.empty(
-            shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
-        )
-        self.shape = template.shape
-        self.dtype = template.dtype
-        self.chunks = template.chunks
-        self.nbytes = template.nbytes
+        self.shape = shape
+        self.dtype = np.dtype(dtype)
+        self.chunks = chunks
         self.store = store
         self.path = path
         self.kwargs = kwargs

+    @property
+    def size(self):
+        """Number of elements in the array."""
+        return reduce(mul, self.shape, 1)
+
+    @property
+    def nbytes(self) -> int:
+        """Number of bytes in array"""
+        return self.size * self.dtype.itemsize
+
     def create(self, mode: str = "w-") -> zarr.Array:
         """Create the Zarr array in storage.

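[Note, not part of the patch: the virtual array classes above use ndindex only to compute the shape that results from applying an index, so no in-memory Zarr template is needed. A minimal sketch of that idea, with a made-up shape and key for illustration:

    import numpy as np
    from ndindex import ndindex

    shape = (12, 10)
    key = (slice(2, 9), 3)  # any basic index expression

    idx = ndindex[key]              # normalize the key into an ndindex object
    newshape = idx.newshape(shape)  # shape of a[key] without creating any data
    assert newshape == np.empty(shape)[key].shape  # (7,)

The broadcast trick then only has to allocate an array of this shape, and LazyZarrArray can likewise derive size and nbytes from shape and dtype alone, as in the zarr.py hunk above.]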
diff --git a/cubed/tests/runtime/test_modal.py b/cubed/tests/runtime/test_modal.py
index 1d66547f..6e341709 100644
--- a/cubed/tests/runtime/test_modal.py
+++ b/cubed/tests/runtime/test_modal.py
@@ -23,6 +23,7 @@
     "donfig",
     "fsspec",
     "mypy_extensions",  # for rechunker
+    "ndindex",
     "networkx",
     "pytest-mock",  # TODO: only needed for tests
     "s3fs",
diff --git a/cubed/tests/test_indexing.py b/cubed/tests/test_indexing.py
index fc7b3d9f..e9ac27e5 100644
--- a/cubed/tests/test_indexing.py
+++ b/cubed/tests/test_indexing.py
@@ -23,8 +23,9 @@ def spec(tmp_path):
     ],
 )
 def test_int_array_index_1d(spec, ind):
-    a = xp.arange(12, chunks=(4,), spec=spec)
-    assert_array_equal(a[ind].compute(), np.arange(12)[ind])
+    a = xp.arange(12, chunks=(3,), spec=spec)
+    b = a.rechunk((4,))  # force materialization to test indexing against zarr
+    assert_array_equal(b[ind].compute(), np.arange(12)[ind])


 @pytest.mark.parametrize(
@@ -40,11 +41,12 @@ def test_int_array_index_1d(spec, ind):
 def test_int_array_index_2d(spec, ind):
     a = xp.asarray(
         [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
-        chunks=(2, 2),
+        chunks=(3, 3),
         spec=spec,
     )
+    b = a.rechunk((2, 2))  # force materialization to test indexing against zarr
     x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
-    assert_array_equal(a[ind].compute(), x[ind])
+    assert_array_equal(b[ind].compute(), x[ind])


 def test_multiple_int_array_indexes(spec):
diff --git a/pyproject.toml b/pyproject.toml
index b850087d..9c28bc33 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "donfig",
     "fsspec",
     "mypy_extensions",  # for rechunker
+    "ndindex",
     "networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*",
     "numpy >= 1.22",
     "tenacity",
diff --git a/requirements.txt b/requirements.txt
index f4855566..6f84c347 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ aiostream
 array-api-compat
 fsspec
 mypy_extensions # for rechunker
+ndindex
 networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*
 numpy >= 1.22
 tenacity
diff --git a/setup.cfg b/setup.cfg
index 62965f08..8606c1e8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,6 +44,8 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-matplotlib.*]
 ignore_missing_imports = True
+[mypy-ndindex.*]
+ignore_missing_imports = True
 [mypy-networkx.*]
 ignore_missing_imports = True
 [mypy-numpy.*]
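[Note, not part of the patch: the comment added to _read_index_chunk relies on Zarr's documented behaviour that a selection containing at most one integer array (and otherwise only slices and integers) is handled by plain __getitem__, so array.oindex is no longer needed there. A small self-contained check of that behaviour against an in-memory Zarr 2.x array, with made-up data:

    import numpy as np
    import zarr

    np_a = np.arange(16).reshape(4, 4)
    za = zarr.array(np_a, chunks=(2, 2))

    sel = ([0, 2], slice(1, 4))  # one integer-array index plus a slice

    # plain indexing agrees with orthogonal indexing for this kind of selection
    np.testing.assert_array_equal(za[sel], za.oindex[sel])
    np.testing.assert_array_equal(za[sel], np_a[np.ix_([0, 2], range(1, 4))])

If a future change ever needed more than one integer array per selection, orthogonal (or coordinate) indexing would be required again.]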