Skip to content

Commit

Permalink
Fix bug with expand dims of a scalar array (zarr-developers#103)
Browse files Browse the repository at this point in the history
* simple function to create manifestarray in tests

* test to expose bug with broadcast_to for scalars

* fix usage of outdated attribute name

* fix bug by special-casing broadcasting of scalar arrays
  • Loading branch information
TomNicholas authored May 8, 2024
1 parent 0f37222 commit 79d6969
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 7 deletions.
51 changes: 47 additions & 4 deletions virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def concatenate(

# Ensure we handle axis being passed as a negative integer
first_arr = arrays[0]
axis = axis % first_arr.ndim
if axis < 0:
axis = axis % first_arr.ndim

arr_shapes = [arr.shape for arr in arrays]
_check_same_shapes_except_on_concat_axis(arr_shapes, axis)
Expand Down Expand Up @@ -154,6 +155,7 @@ def _check_same_ndims(ndims: list[int]) -> None:

def _check_same_shapes_except_on_concat_axis(shapes: list[tuple[int, ...]], axis: int):
"""Check that shapes are compatible for concatenation"""

shapes_without_concat_axis = [
_remove_element_at_position(shape, axis) for shape in shapes
]
Expand Down Expand Up @@ -198,7 +200,8 @@ def stack(

# Ensure we handle axis being passed as a negative integer
first_arr = arrays[0]
axis = axis % first_arr.ndim
if axis < 0:
axis = axis % first_arr.ndim

# find what new array shape must be
length_along_new_stacked_axis = len(arrays)
Expand Down Expand Up @@ -267,8 +270,13 @@ def broadcast_to(x: "ManifestArray", /, shape: Tuple[int, ...]) -> "ManifestArra
if d == d_requested:
pass
elif d is None:
# stack same array upon itself d_requested number of times, which inserts a new axis at axis=0
result = stack([result] * d_requested, axis=0)
if result.shape == ():
# scalars are a special case because their manifests already have a chunk key with one dimension
# see https://github.com/TomNicholas/VirtualiZarr/issues/100#issuecomment-2097058282
result = _broadcast_scalar(result, new_axis_length=d_requested)
else:
# stack same array upon itself d_requested number of times, which inserts a new axis at axis=0
result = stack([result] * d_requested, axis=0)
elif d == 1:
# concatenate same array upon itself d_requested number of times along existing axis
result = concatenate([result] * d_requested, axis=axis)
Expand All @@ -280,6 +288,41 @@ def broadcast_to(x: "ManifestArray", /, shape: Tuple[int, ...]) -> "ManifestArra
return result


def _broadcast_scalar(x: "ManifestArray", new_axis_length: int) -> "ManifestArray":
"""
Add an axis to a scalar ManifestArray, but without adding a new axis to the keys of the chunk manifest.
This is not the same as concatenation, because there is no existing axis along which to concatenate.
It's also not the same as stacking, because we don't want to insert a new axis into the chunk keys.
Scalars are a special case because their manifests still have a chunk key with one dimension.
See https://github.com/TomNicholas/VirtualiZarr/issues/100#issuecomment-2097058282
"""

from .array import ManifestArray

new_shape = (new_axis_length,)
new_chunks = (new_axis_length,)

concatenated_manifest = concat_manifests(
[x.manifest] * new_axis_length,
axis=0,
)

new_zarray = ZArray(
chunks=new_chunks,
compressor=x.zarray.compressor,
dtype=x.dtype,
fill_value=x.zarray.fill_value,
filters=x.zarray.filters,
shape=new_shape,
order=x.zarray.order,
zarr_format=x.zarray.zarr_format,
)

return ManifestArray(chunkmanifest=concatenated_manifest, zarray=new_zarray)


# TODO broadcast_arrays, squeeze, permute_dims


Expand Down
6 changes: 3 additions & 3 deletions virtualizarr/manifests/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def __repr__(self) -> str:
return f"ChunkManifest<shape={self.shape_chunk_grid}>"

def __getitem__(self, key: ChunkKey) -> ChunkEntry:
return self.chunks[key]
return self.entries[key]

def __iter__(self) -> Iterator[ChunkKey]:
return iter(self.chunks.keys())
return iter(self.entries.keys())

def __len__(self) -> int:
return len(self.chunks)
return len(self.entries)

def dict(self) -> dict[str, dict[str, Union[str, int]]]:
"""Converts the entire manifest to a nested dictionary."""
Expand Down
35 changes: 35 additions & 0 deletions virtualizarr/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np

from virtualizarr.manifests import ChunkEntry, ChunkManifest, ManifestArray
from virtualizarr.zarr import ZArray


def create_manifestarray(
shape: tuple[int, ...], chunks: tuple[int, ...]
) -> ManifestArray:
"""
Create an example ManifestArray with sensible defaults.
"""

zarray = ZArray(
chunks=chunks,
compressor="zlib",
dtype=np.dtype("float32"),
fill_value=0.0, # TODO change this to NaN?
filters=None,
order="C",
shape=shape,
zarr_format=2,
)

if shape != ():
raise NotImplementedError(
"Only generation of array representing a single scalar currently supported"
)

# TODO generalize this
chunkmanifest = ChunkManifest(
entries={"0": ChunkEntry(path="scalar.nc", offset=6144, length=48)}
)

return ManifestArray(chunkmanifest=chunkmanifest, zarray=zarray)
11 changes: 11 additions & 0 deletions virtualizarr/tests/test_manifests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

from virtualizarr.manifests import ChunkManifest, ManifestArray
from virtualizarr.tests import create_manifestarray
from virtualizarr.zarr import ZArray


Expand Down Expand Up @@ -122,6 +123,16 @@ def test_not_equal_chunk_entries(self):
def test_partly_equals(self): ...


class TestBroadcast:
def test_broadcast_scalar(self):
# regression test
marr = create_manifestarray(shape=(), chunks=())
expanded = np.broadcast_to(marr, shape=(1,))
assert expanded.shape == (1,)
assert expanded.chunks == (1,)
assert expanded.manifest == marr.manifest


# TODO we really need some kind of fixtures to generate useful example data
# The hard part is having an alternative way to get to the expected result of concatenation
class TestConcat:
Expand Down

0 comments on commit 79d6969

Please sign in to comment.