Skip to content

Commit

Permalink
add and test operations
Browse files Browse the repository at this point in the history
  • Loading branch information
melonora committed Nov 19, 2024
1 parent f27f910 commit deb167b
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 23 deletions.
28 changes: 27 additions & 1 deletion multiscale_spatial_image/multiscale_spatial_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from collections.abc import MutableMapping, Hashable
from pathlib import Path
from zarr.storage import BaseStore
from multiscale_spatial_image.operations import transpose
from multiscale_spatial_image.operations import (
transpose,
reindex_data_arrays,
assign_coords,
)


@register_datatree_accessor("msi")
Expand Down Expand Up @@ -137,3 +141,25 @@ def transpose(self, *dims: Hashable) -> DataTree:
reorder the dimensions to the order that the `dims` are specified..
"""
return self._dt.map_over_datasets(transpose, *dims)

def reindex_data_arrays(
self,
indexers,
method=None,
tolerance=None,
copy=False,
fill_value=None,
**indexer_kwargs,
):
return self._dt.map_over_datasets(
reindex_data_arrays,
indexers,
method,
tolerance,
copy,
fill_value,
*indexer_kwargs,
)

def assign_coords(self, coords, **coords_kwargs):
return self._dt.map_over_datasets(assign_coords, coords, *coords_kwargs)
4 changes: 2 additions & 2 deletions multiscale_spatial_image/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .operations import assign_coords, transpose, reindex
from .operations import assign_coords, transpose, reindex_data_arrays

__all__ = ["assign_coords", "transpose", "reindex"]
__all__ = ["assign_coords", "transpose", "reindex_data_arrays"]
8 changes: 2 additions & 6 deletions multiscale_spatial_image/operations/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,5 @@ def transpose(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:


@skip_non_dimension_nodes
def reindex(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:
# A copy is required as a dataset view as used in map_over_datasets is not mutable
# TODO: Check whether setting item on wrapping datatree node would be better than copy or this can be dropped.
ds_copy = ds.copy()
ds_copy["image"] = ds_copy["image"].reindex(*args, **kwargs)
return ds_copy
def reindex_data_arrays(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:
return ds["image"].reindex(*args, **kwargs).to_dataset()
13 changes: 13 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
import numpy as np
from spatial_image import to_spatial_image
from multiscale_spatial_image import to_multiscale


@pytest.fixture()
def multiscale_data():
data = np.zeros((3, 200, 200))
dims = ("c", "y", "x")
scale_factors = [2, 2]
image = to_spatial_image(array_like=data, dims=dims)
return to_multiscale(image, scale_factors=scale_factors)
17 changes: 17 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def test_transpose(multiscale_data):
multiscale_data = multiscale_data.msi.transpose("y", "x", "c")

for scale in list(multiscale_data.keys()):
assert multiscale_data[scale]["image"].dims == ("y", "x", "c")


def test_reindex_arrays(multiscale_data):
multiscale_data = multiscale_data.msi.reindex_data_arrays({"c": ["r", "g", "b"]})
for scale in list(multiscale_data.keys()):
assert multiscale_data[scale].c.data.tolist() == ["r", "g", "b"]


def test_assign_coords(multiscale_data):
multiscale_data = multiscale_data.msi.assign_coords({"c": ["r", "g", "b"]})
for scale in list(multiscale_data.keys()):
assert multiscale_data[scale].c.data.tolist() == ["r", "g", "b"]
20 changes: 6 additions & 14 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
import numpy as np
from spatial_image import to_spatial_image
from multiscale_spatial_image import skip_non_dimension_nodes, to_multiscale
from multiscale_spatial_image import skip_non_dimension_nodes


def test_skip_nodes():
data = np.zeros((2, 200, 200))
dims = ("c", "y", "x")
scale_factors = [2, 2]
image = to_spatial_image(array_like=data, dims=dims)
multiscale_img = to_multiscale(image, scale_factors=scale_factors)

def test_skip_nodes(multiscale_data):
@skip_non_dimension_nodes
def transpose(ds, *args, **kwargs):
return ds.transpose(*args, **kwargs)

for scale in list(multiscale_img.keys()):
assert multiscale_img[scale]["image"].dims == ("c", "y", "x")
multiscale_img.msi.transpose()
for scale in list(multiscale_data.keys()):
assert multiscale_data[scale]["image"].dims == ("c", "y", "x")

# applying this function without skipping the root node would fail as the root node does not have dimensions.
result = multiscale_img.map_over_datasets(transpose, "y", "x", "c")
result = multiscale_data.map_over_datasets(transpose, "y", "x", "c")
for scale in list(result.keys()):
assert result[scale]["image"].dims == ("y", "x", "c")

0 comments on commit deb167b

Please sign in to comment.