From 9200daebac4dea8ff1fa7bbbfb5a800792b06959 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 9 Jan 2024 16:35:53 +0100 Subject: [PATCH 1/4] made _locate_spatial_element public, renamed to locate_element() --- src/spatialdata/_core/spatialdata.py | 41 +++++-------------- src/spatialdata/transformations/operations.py | 8 ++-- .../operations/test_spatialdata_operations.py | 7 ++-- 3 files changed, 18 insertions(+), 38 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 94027355..4c02bec7 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -312,24 +312,23 @@ def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> za return elem_group return root - def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]: + def locate_element(self, element: SpatialElement) -> tuple[str, str] | None: """ - Find the SpatialElement within the SpatialData object. + Return the name and the type of a SpatialElement within the SpatialData object. Parameters ---------- element The queried SpatialElement - Returns ------- - name and type of the element + name and type of the element; if the element is not found, None is returned instead Raises ------ ValueError - the element is not found or found multiple times in the SpatialData object + the element is found multiple times in the SpatialData object """ found: list[SpatialElement] = [] found_element_type: list[str] = [] @@ -341,7 +340,7 @@ def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]: found_element_type.append(element_type) found_element_name.append(element_name) if len(found) == 0: - raise ValueError("Element not found in the SpatialData object.") + return None if len(found) > 1: raise ValueError( f"Element found multiple times in the SpatialData object." @@ -352,29 +351,6 @@ def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]: assert len(found_element_type) == 1 return found_element_name[0], found_element_type[0] - def contains_element(self, element: SpatialElement, raise_exception: bool = False) -> bool: - """ - Check if the SpatialElement is contained in the SpatialData object. - - Parameters - ---------- - element - The SpatialElement to check - raise_exception - If True, raise an exception if the element is not found. If False, return False if the element is not found. - - Returns - ------- - True if the element is found; False otherwise (if raise_exception is False). - """ - try: - self._locate_spatial_element(element) - return True - except ValueError as e: - if raise_exception: - raise e - return False - def _write_transformations_to_disk(self, element: SpatialElement) -> None: """ Write transformations to disk for an element. @@ -388,7 +364,12 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None: transformations = get_transformation(element, get_all=True) assert isinstance(transformations, dict) - found_element_name, found_element_type = self._locate_spatial_element(element) + located = self.locate_element(element) + if located is None: + raise ValueError( + "Cannot save the transformation to the element as it has not been found in the SpatialData object" + ) + found_element_name, found_element_type = located if self.path is not None: group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) diff --git a/src/spatialdata/transformations/operations.py b/src/spatialdata/transformations/operations.py index 61792777..102a9cae 100644 --- a/src/spatialdata/transformations/operations.py +++ b/src/spatialdata/transformations/operations.py @@ -68,8 +68,8 @@ def set_transformation( assert to_coordinate_system is None _set_transformations(element, transformation) else: - if not write_to_sdata.contains_element(element, raise_exception=True): - raise RuntimeError("contains_element() failed without raising an exception.") + if write_to_sdata.locate_element(element) is None: + raise RuntimeError("The element is not found in the SpatialData object.") if not write_to_sdata.is_backed(): raise ValueError( "The SpatialData object is not backed. You can either set a transformation to an element " @@ -164,8 +164,8 @@ def remove_transformation( assert to_coordinate_system is None _set_transformations(element, {}) else: - if not write_to_sdata.contains_element(element, raise_exception=True): - raise RuntimeError("contains_element() failed without raising an exception.") + if write_to_sdata.locate_element(element) is None: + raise RuntimeError("The element is not found in the SpatialData object.") if not write_to_sdata.is_backed(): raise ValueError( "The SpatialData object is not backed. You can either remove a transformation from an " diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 51314937..63714f72 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -308,15 +308,14 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: def test_locate_spatial_element(full_sdata: SpatialData) -> None: - assert full_sdata._locate_spatial_element(full_sdata.images["image2d"]) == ("image2d", "images") + assert full_sdata.locate_element(full_sdata.images["image2d"]) == ("image2d", "images") im = full_sdata.images["image2d"] del full_sdata.images["image2d"] - with pytest.raises(ValueError, match="Element not found in the SpatialData object."): - full_sdata._locate_spatial_element(im) + assert full_sdata.locate_element(im) is None full_sdata.images["image2d"] = im full_sdata.images["image2d_again"] = im with pytest.raises(ValueError): - full_sdata._locate_spatial_element(im) + full_sdata.locate_element(im) def test_get_item(points: SpatialData) -> None: From cc8d1738ed61d1bab98f3976066b0c50bec53dcf Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 16 Jan 2024 20:12:22 +0100 Subject: [PATCH 2/4] returning path instead of tuple in locate_element() --- src/spatialdata/_core/spatialdata.py | 10 +++++----- tests/core/operations/test_spatialdata_operations.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 0766d52b..e26681cc 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -312,9 +312,9 @@ def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> za return elem_group return root - def locate_element(self, element: SpatialElement) -> tuple[str, str] | None: + def locate_element(self, element: SpatialElement) -> str | None: """ - Return the name and the type of a SpatialElement within the SpatialData object. + Locate a SpatialElement within the SpatialData object and, if found, returns its Zarr path relative to the root. Parameters ---------- @@ -323,7 +323,7 @@ def locate_element(self, element: SpatialElement) -> tuple[str, str] | None: Returns ------- - name and type of the element; if the element is not found, None is returned instead + The Zarr path of the element relative to the root, or None if the element is not found. Raises ------ @@ -349,7 +349,7 @@ def locate_element(self, element: SpatialElement) -> tuple[str, str] | None: ) assert len(found_element_name) == 1 assert len(found_element_type) == 1 - return found_element_name[0], found_element_type[0] + return f"{found_element_type[0]}/{found_element_name[0]}" def _write_transformations_to_disk(self, element: SpatialElement) -> None: """ @@ -369,7 +369,7 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None: raise ValueError( "Cannot save the transformation to the element as it has not been found in the SpatialData object" ) - found_element_name, found_element_type = located + found_element_type, found_element_name = located.split("/") if self.path is not None: group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 6d2dd9ad..84a41f60 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -310,7 +310,7 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: def test_locate_spatial_element(full_sdata: SpatialData) -> None: - assert full_sdata.locate_element(full_sdata.images["image2d"]) == ("image2d", "images") + assert full_sdata.locate_element(full_sdata.images["image2d"]) == "images/image2d" im = full_sdata.images["image2d"] del full_sdata.images["image2d"] assert full_sdata.locate_element(im) is None From 3d5025242de6de7674da1f797382637cfca18528 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 16 Jan 2024 20:13:36 +0100 Subject: [PATCH 3/4] updated changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d763506..f89cee68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning][]. ### Added - added SpatialData.subset() API +- added SpatialData.locate_element() API ### Fixed From 615bd79304d3a8166f76e176114b47394e5a088f Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 17 Jan 2024 13:32:22 +0100 Subject: [PATCH 4/4] locate_elements() now returns a list --- src/spatialdata/_core/spatialdata.py | 54 ++++++++----------- .../operations/test_spatialdata_operations.py | 6 +-- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index e26681cc..1d5b149e 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -312,7 +312,7 @@ def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> za return elem_group return root - def locate_element(self, element: SpatialElement) -> str | None: + def locate_element(self, element: SpatialElement) -> list[str] | None: """ Locate a SpatialElement within the SpatialData object and, if found, returns its Zarr path relative to the root. @@ -323,12 +323,8 @@ def locate_element(self, element: SpatialElement) -> str | None: Returns ------- - The Zarr path of the element relative to the root, or None if the element is not found. - - Raises - ------ - ValueError - the element is found multiple times in the SpatialData object + A list of Zarr paths of the element relative to the root (multiple copies of the same element are allowed), or + None if the element is not found. """ found: list[SpatialElement] = [] found_element_type: list[str] = [] @@ -341,15 +337,7 @@ def locate_element(self, element: SpatialElement) -> str | None: found_element_name.append(element_name) if len(found) == 0: return None - if len(found) > 1: - raise ValueError( - f"Element found multiple times in the SpatialData object." - f"Found {len(found)} elements with names: {found_element_name}," - f" and types: {found_element_type}" - ) - assert len(found_element_name) == 1 - assert len(found_element_type) == 1 - return f"{found_element_type[0]}/{found_element_name[0]}" + return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] def _write_transformations_to_disk(self, element: SpatialElement) -> None: """ @@ -369,25 +357,27 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None: raise ValueError( "Cannot save the transformation to the element as it has not been found in the SpatialData object" ) - found_element_type, found_element_name = located.split("/") - if self.path is not None: - group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) - axes = get_axes_names(element) - if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + for path in located: + found_element_type, found_element_name = path.split("/") + group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) + axes = get_axes_names(element) + if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): + from spatialdata._io._utils import ( + overwrite_coordinate_transformations_raster, + ) - overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations) - elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations) + elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)): + from spatialdata._io._utils import ( + overwrite_coordinate_transformations_non_raster, + ) - overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) - else: - raise ValueError("Unknown element type") + overwrite_coordinate_transformations_non_raster( + group=group, axes=axes, transformations=transformations + ) + else: + raise ValueError("Unknown element type") def filter_by_coordinate_system(self, coordinate_system: str | list[str], filter_table: bool = True) -> SpatialData: """ diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 84a41f60..861acee2 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -310,14 +310,14 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: def test_locate_spatial_element(full_sdata: SpatialData) -> None: - assert full_sdata.locate_element(full_sdata.images["image2d"]) == "images/image2d" + assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d" im = full_sdata.images["image2d"] del full_sdata.images["image2d"] assert full_sdata.locate_element(im) is None full_sdata.images["image2d"] = im full_sdata.images["image2d_again"] = im - with pytest.raises(ValueError): - full_sdata.locate_element(im) + paths = full_sdata.locate_element(im) + assert len(paths) == 2 def test_get_item(points: SpatialData) -> None: