Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement virtual array indexing using ndindex #441

Merged
merged 1 commit into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
OrthogonalIndexer,
SliceDimIndexer,
is_integer_list,
is_slice,
replace_ellipsis,
)

Expand Down Expand Up @@ -409,7 +408,7 @@ def index(x, key):
key = (key,)

# No op case
if all(is_slice(ind) and ind == slice(None) for ind in key):
if all(isinstance(ind, slice) and ind == slice(None) for ind in key):
return x

# Remove None values, to be filled in with expand_dims at end
Expand All @@ -436,7 +435,9 @@ def index(x, key):
selection = replace_ellipsis(selection, x.shape)

# Check selection is supported
if any(s.step is not None and s.step < 1 for s in selection if is_slice(s)):
if any(
s.step is not None and s.step < 1 for s in selection if isinstance(s, slice)
):
raise NotImplementedError(f"Slice step must be >= 1: {key}")
assert all(isinstance(s, (slice, list, Integral)) for s in selection)
where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
Expand Down Expand Up @@ -490,7 +491,6 @@ def merged_chunk_len_for_indexer(s):
extra_projected_mem=extra_projected_mem,
target_chunks=target_chunks,
selection=selection,
advanced_indexing=len(where_list) > 0,
)

# merge chunks for any dims with step > 1 so they are
Expand All @@ -516,13 +516,14 @@ def _read_index_chunk(
*arrays,
target_chunks=None,
selection=None,
advanced_indexing=None,
block_id=None,
):
array = arrays[0].zarray
if advanced_indexing:
array = array.oindex
idx = block_id
# Note that since we have at most one integer array index, we don't need
# to go through Zarr's `oindex` accessor: in that case orthogonal indexing
# is "available directly on the array" according to
# https://zarr.readthedocs.io/en/stable/tutorial.html#orthogonal-indexing
out = array[_target_chunk_selection(target_chunks, idx, selection)]
out = numpy_array_to_backend_array(out)
return out
Expand All @@ -535,7 +536,7 @@ def _target_chunk_selection(target_chunks, idx, selection):
sel = []
i = 0 # index into target_chunks and idx
for s in selection:
if is_slice(s):
if isinstance(s, slice):
offset = s.start or 0
step = s.step if s.step is not None else 1
start = tuple(
Expand Down
2 changes: 2 additions & 0 deletions cubed/runtime/executors/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"s3fs",
Expand All @@ -52,6 +53,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"gcsfs",
Expand Down
95 changes: 24 additions & 71 deletions cubed/storage/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
from typing import Any

import numpy as np
import zarr
from zarr.indexing import BasicIndexer, is_slice
from ndindex import ndindex

from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.types import T_DType, T_RegularChunks, T_Shape
Expand All @@ -21,31 +19,21 @@ def __init__(
dtype: T_DType,
chunks: T_RegularChunks,
):
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks

def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
idx = ndindex[key]
newshape = idx.newshape(self.shape)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)
return broadcast_trick(nxp.empty)(newshape, dtype=self.dtype)

@property
def chunkmem(self):
# with the broadcast trick a chunk only occupies a single value in memory
return array_memory(self.dtype, (1,))

@property
def oindex(self):
return self.template.oindex


class VirtualFullArray:
"""An array that is never materialized (in memory or on disk) and contains a single fill value."""
Expand All @@ -57,53 +45,29 @@ def __init__(
chunks: T_RegularChunks,
fill_value: Any = None,
):
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.full(
shape,
fill_value,
dtype=dtype,
chunks=chunks,
store=zarr.storage.MemoryStore(),
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks
self.fill_value = fill_value

def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
idx = ndindex[key]
newshape = idx.newshape(self.shape)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.full)(
indexer.shape, fill_value=self.fill_value, dtype=self.dtype
newshape, fill_value=self.fill_value, dtype=self.dtype
)

@property
def chunkmem(self):
# with the broadcast trick a chunk only occupies a single value in memory
return array_memory(self.dtype, (1,))

@property
def oindex(self):
return self.template.oindex


class VirtualOffsetsArray:
"""An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers."""

def __init__(self, shape: T_Shape):
dtype = nxp.int32
chunks = (1,) * len(shape)
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.ndim = template.ndim
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks

def __getitem__(self, key):
if key == () and self.shape == ():
Expand All @@ -127,28 +91,13 @@ def __init__(
f"Size of in memory array is {memory_repr(array.nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`."
)
self.array = array
# use an in-memory Zarr array as a template since it normalizes its properties
# and is needed for oindex
template = zarr.empty(
array.shape,
dtype=array.dtype,
chunks=chunks,
store=zarr.storage.MemoryStore(),
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
if array.size > 0:
template[...] = backend_array_to_numpy_array(array)
self.shape = array.shape
self.dtype = array.dtype
self.chunks = chunks

def __getitem__(self, key):
return self.array.__getitem__(key)

@property
def oindex(self):
return self.template.oindex


def _key_to_index_tuple(selection):
if isinstance(selection, slice):
Expand All @@ -158,7 +107,11 @@ def _key_to_index_tuple(selection):
for s in selection:
if isinstance(s, Integral):
sel.append(s)
elif is_slice(s) and s.stop == s.start + 1 and (s.step is None or s.step == 1):
elif (
isinstance(s, slice)
and s.stop == s.start + 1
and (s.step is None or s.step == 1)
):
sel.append(s.start)
else:
raise NotImplementedError(f"Offset selection not supported: {selection}")
Expand Down
24 changes: 16 additions & 8 deletions cubed/storage/zarr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from operator import mul
from typing import Optional, Union

import numpy as np
import zarr
from toolz import reduce

from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store

Expand All @@ -23,18 +26,23 @@ def __init__(
**kwargs,
):
"""Create a Zarr array lazily in memory."""
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.nbytes = template.nbytes
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks
self.store = store
self.path = path
self.kwargs = kwargs

@property
def size(self):
"""Number of elements in the array."""
return reduce(mul, self.shape, 1)

@property
def nbytes(self) -> int:
"""Number of bytes in array"""
return self.size * self.dtype.itemsize

def create(self, mode: str = "w-") -> zarr.Array:
"""Create the Zarr array in storage.

Expand Down
1 change: 1 addition & 0 deletions cubed/tests/runtime/test_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"s3fs",
Expand Down
10 changes: 6 additions & 4 deletions cubed/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def spec(tmp_path):
],
)
def test_int_array_index_1d(spec, ind):
a = xp.arange(12, chunks=(4,), spec=spec)
assert_array_equal(a[ind].compute(), np.arange(12)[ind])
a = xp.arange(12, chunks=(3,), spec=spec)
b = a.rechunk((4,)) # force materialization to test indexing against zarr
assert_array_equal(b[ind].compute(), np.arange(12)[ind])


@pytest.mark.parametrize(
Expand All @@ -40,11 +41,12 @@ def test_int_array_index_1d(spec, ind):
def test_int_array_index_2d(spec, ind):
a = xp.asarray(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(2, 2),
chunks=(3, 3),
spec=spec,
)
b = a.rechunk((2, 2)) # force materialization to test indexing against zarr
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
assert_array_equal(a[ind].compute(), x[ind])
assert_array_equal(b[ind].compute(), x[ind])


def test_multiple_int_array_indexes(spec):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*",
"numpy >= 1.22",
"tenacity",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ aiostream
array-api-compat
fsspec
mypy_extensions # for rechunker
ndindex
networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*
numpy >= 1.22
tenacity
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-ndindex.*]
ignore_missing_imports = True
[mypy-networkx.*]
ignore_missing_imports = True
[mypy-numpy.*]
Expand Down
Loading