Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made _locate_spatial_element public, renamed to locate_element() #427

Merged
merged 5 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning][].
### Added

- added SpatialData.subset() API
- added SpatialData.locate_element() API

### Fixed

Expand Down
87 changes: 29 additions & 58 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,24 +312,19 @@ def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> za
return elem_group
return root

def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]:
def locate_element(self, element: SpatialElement) -> list[str] | None:
"""
Find the SpatialElement within the SpatialData object.
Locate a SpatialElement within the SpatialData object and, if found, returns its Zarr path relative to the root.

Parameters
----------
element
The queried SpatialElement


Returns
-------
name and type of the element

Raises
------
ValueError
the element is not found or found multiple times in the SpatialData object
A list of Zarr paths of the element relative to the root (multiple copies of the same element are allowed), or
None if the element is not found.
"""
found: list[SpatialElement] = []
found_element_type: list[str] = []
Expand All @@ -341,39 +336,8 @@ def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]:
found_element_type.append(element_type)
found_element_name.append(element_name)
if len(found) == 0:
raise ValueError("Element not found in the SpatialData object.")
if len(found) > 1:
raise ValueError(
f"Element found multiple times in the SpatialData object."
f"Found {len(found)} elements with names: {found_element_name},"
f" and types: {found_element_type}"
)
assert len(found_element_name) == 1
assert len(found_element_type) == 1
return found_element_name[0], found_element_type[0]

def contains_element(self, element: SpatialElement, raise_exception: bool = False) -> bool:
"""
Check if the SpatialElement is contained in the SpatialData object.

Parameters
----------
element
The SpatialElement to check
raise_exception
If True, raise an exception if the element is not found. If False, return False if the element is not found.

Returns
-------
True if the element is found; False otherwise (if raise_exception is False).
"""
try:
self._locate_spatial_element(element)
return True
except ValueError as e:
if raise_exception:
raise e
return False
return None
LucaMarconato marked this conversation as resolved.
Show resolved Hide resolved
return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))]

def _write_transformations_to_disk(self, element: SpatialElement) -> None:
"""
Expand All @@ -388,25 +352,32 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None:

transformations = get_transformation(element, get_all=True)
assert isinstance(transformations, dict)
found_element_name, found_element_type = self._locate_spatial_element(element)

located = self.locate_element(element)
if located is None:
raise ValueError(
"Cannot save the transformation to the element as it has not been found in the SpatialData object"
)
if self.path is not None:
group = self._get_group_for_element(name=found_element_name, element_type=found_element_type)
axes = get_axes_names(element)
if isinstance(element, (SpatialImage, MultiscaleSpatialImage)):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_raster,
)
for path in located:
found_element_type, found_element_name = path.split("/")
group = self._get_group_for_element(name=found_element_name, element_type=found_element_type)
axes = get_axes_names(element)
if isinstance(element, (SpatialImage, MultiscaleSpatialImage)):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_raster,
)

overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations)
elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_non_raster,
)
overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations)
elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_non_raster,
)

overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations)
else:
raise ValueError("Unknown element type")
overwrite_coordinate_transformations_non_raster(
group=group, axes=axes, transformations=transformations
)
else:
raise ValueError("Unknown element type")

def filter_by_coordinate_system(self, coordinate_system: str | list[str], filter_table: bool = True) -> SpatialData:
"""
Expand Down
8 changes: 4 additions & 4 deletions src/spatialdata/transformations/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def set_transformation(
assert to_coordinate_system is None
_set_transformations(element, transformation)
else:
if not write_to_sdata.contains_element(element, raise_exception=True):
raise RuntimeError("contains_element() failed without raising an exception.")
if write_to_sdata.locate_element(element) is None:
raise RuntimeError("The element is not found in the SpatialData object.")
if not write_to_sdata.is_backed():
raise ValueError(
"The SpatialData object is not backed. You can either set a transformation to an element "
Expand Down Expand Up @@ -164,8 +164,8 @@ def remove_transformation(
assert to_coordinate_system is None
_set_transformations(element, {})
else:
if not write_to_sdata.contains_element(element, raise_exception=True):
raise RuntimeError("contains_element() failed without raising an exception.")
if write_to_sdata.locate_element(element) is None:
raise RuntimeError("The element is not found in the SpatialData object.")
if not write_to_sdata.is_backed():
raise ValueError(
"The SpatialData object is not backed. You can either remove a transformation from an "
Expand Down
9 changes: 4 additions & 5 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,14 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None:


def test_locate_spatial_element(full_sdata: SpatialData) -> None:
assert full_sdata._locate_spatial_element(full_sdata.images["image2d"]) == ("image2d", "images")
assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d"
im = full_sdata.images["image2d"]
del full_sdata.images["image2d"]
with pytest.raises(ValueError, match="Element not found in the SpatialData object."):
full_sdata._locate_spatial_element(im)
assert full_sdata.locate_element(im) is None
full_sdata.images["image2d"] = im
full_sdata.images["image2d_again"] = im
with pytest.raises(ValueError):
full_sdata._locate_spatial_element(im)
paths = full_sdata.locate_element(im)
assert len(paths) == 2


def test_get_item(points: SpatialData) -> None:
Expand Down
Loading