diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3f7aebe..533586a0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][].
 
 ## [0.0.x] - tbd
 
+### Minor
+
+- improved usability and robustness of sdata.write() when overwrite=True @aeisenbarth
+
 ### Added
 
 ### Fixed
diff --git a/docs/api.md b/docs/api.md
index 9034b0d9..93509ffd 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -29,6 +29,7 @@ Operations on `SpatialData` objects.
     get_extent
     match_table_to_element
     concatenate
+    transform
     rasterize
     aggregate
 ```
@@ -133,4 +134,5 @@ The transformations that can be defined between elements and coordinate systems
 
     read_zarr
     save_transformations
+    get_dask_backing_files
 ```
diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py
index 0541c491..e09f42c0 100644
--- a/src/spatialdata/__init__.py
+++ b/src/spatialdata/__init__.py
@@ -28,6 +28,7 @@
     "read_zarr",
     "unpad_raster",
     "save_transformations",
+    "get_dask_backing_files",
 ]
 
 from spatialdata import dataloader, models, transformations
@@ -40,6 +41,6 @@
 from spatialdata._core.query.relational_query import get_values, match_table_to_element
 from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query
 from spatialdata._core.spatialdata import SpatialData
-from spatialdata._io._utils import save_transformations
+from spatialdata._io._utils import get_dask_backing_files, save_transformations
 from spatialdata._io.io_zarr import read_zarr
 from spatialdata._utils import unpad_raster
diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index df6e57f6..17672120 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -571,6 +571,7 @@ def write(
         consolidate_metadata: bool = True,
     ) -> None:
         from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table
+        from spatialdata._io._utils import get_dask_backing_files
 
         """Write the SpatialData object to Zarr."""
         if isinstance(file_path, str):
@@ -583,6 +584,7 @@
         # old code to support overwriting the backing file
         # target_path = None
         # tmp_zarr_file = None
+
         if os.path.exists(file_path):
             if parse_url(file_path, mode="r") is None:
                 raise ValueError(
@@ -590,14 +592,22 @@
                     "a Zarr store. Overwriting non-Zarr stores is not supported to prevent accidental "
                     "data loss."
                 )
-            if not overwrite and str(self.path) != str(file_path):
+            if not overwrite:
                 raise ValueError("The Zarr store already exists. Use `overwrite=True` to overwrite the store.")
-            raise ValueError(
-                "The file path specified is the same as the one used for backing. "
-                "Overwriting the backing file is not supported to prevent accidental data loss."
-                "We are discussing how to support this use case in the future, if you would like us to "
-                "support it please leave a comment on https://github.com/scverse/spatialdata/pull/138"
-            )
+            if self.is_backed() and str(self.path) == str(file_path):
+                raise ValueError(
+                    "The file path specified is the same as the one used for backing. "
+                    "Overwriting the backing file is not supported to prevent accidental data loss. "
+ "We are discussing how to support this use case in the future, if you would like us to " + "support it please leave a comment on https://github.com/scverse/spatialdata/pull/138" + ) + if any(Path(fp).resolve().is_relative_to(file_path.resolve()) for fp in get_dask_backing_files(self)): + raise ValueError( + "The file path specified is a parent directory of one or more files used for backing for one or " + "more elements in the SpatialData object. You can either load every element of the SpatialData " + "object in memory, or save the current spatialdata object to a different path." + ) + # old code to support overwriting the backing file # else: # target_path = tempfile.TemporaryDirectory() diff --git a/src/spatialdata/_io/__init__.py b/src/spatialdata/_io/__init__.py index fd72da5c..d9fc3cd6 100644 --- a/src/spatialdata/_io/__init__.py +++ b/src/spatialdata/_io/__init__.py @@ -1,3 +1,4 @@ +from spatialdata._io._utils import get_dask_backing_files from spatialdata._io.format import SpatialDataFormatV01 from spatialdata._io.io_points import write_points from spatialdata._io.io_raster import write_image, write_labels @@ -11,4 +12,5 @@ "write_shapes", "write_table", "SpatialDataFormatV01", + "get_dask_backing_files", ] diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index c2d44114..37be41fa 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -11,12 +11,12 @@ from typing import Any import zarr +from dask.array.core import Array as DaskArray from dask.dataframe.core import DataFrame as DaskDataFrame from multiscale_spatial_image import MultiscaleSpatialImage from ome_zarr.format import Format from ome_zarr.writer import _get_valid_axes from spatial_image import SpatialImage -from xarray import DataArray from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import iterate_pyramid_levels @@ -203,37 +203,59 @@ def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool: return _are_directories_identical(os.path.join(tmpdir, "a.zarr"), os.path.join(tmpdir, "b.zarr")) -def _get_backing_files_raster(raster: DataArray) -> list[str]: - files = [] - for k, v in raster.data.dask.layers.items(): - if k.startswith("original-from-zarr-"): - mapping = v.mapping[k] - path = mapping.store.path - files.append(os.path.realpath(path)) - return files +@singledispatch +def get_dask_backing_files(element: SpatialData | SpatialImage | MultiscaleSpatialImage | DaskDataFrame) -> list[str]: + """ + Get the backing files that appear in the Dask computational graph of an element/any element of a SpatialData object. + Parameters + ---------- + element + The element to get the backing files from. -@singledispatch -def get_backing_files(element: SpatialImage | MultiscaleSpatialImage | DaskDataFrame) -> list[str]: + Returns + ------- + List of backing files. + + Notes + ----- + It is possible for lazy objects to be constructed from multiple files. 
+ """ raise TypeError(f"Unsupported type: {type(element)}") -@get_backing_files.register(SpatialImage) +@get_dask_backing_files.register(SpatialData) +def _(element: SpatialData) -> list[str]: + files: set[str] = set() + for e in element._gen_elements_values(): + if isinstance(e, (SpatialImage, MultiscaleSpatialImage, DaskDataFrame)): + files = files.union(get_dask_backing_files(e)) + return list(files) + + +@get_dask_backing_files.register(SpatialImage) def _(element: SpatialImage) -> list[str]: - return _get_backing_files_raster(element) + return _get_backing_files(element.data) -@get_backing_files.register(MultiscaleSpatialImage) +@get_dask_backing_files.register(MultiscaleSpatialImage) def _(element: MultiscaleSpatialImage) -> list[str]: xdata0 = next(iter(iterate_pyramid_levels(element))) - return _get_backing_files_raster(xdata0) + return _get_backing_files(xdata0.data) -@get_backing_files.register(DaskDataFrame) +@get_dask_backing_files.register(DaskDataFrame) def _(element: DaskDataFrame) -> list[str]: + return _get_backing_files(element) + + +def _get_backing_files(element: DaskArray | DaskDataFrame) -> list[str]: files = [] - layers = element.dask.layers - for k, v in layers.items(): + for k, v in element.dask.layers.items(): + if k.startswith("original-from-zarr-"): + mapping = v.mapping[k] + path = mapping.store.path + files.append(os.path.realpath(path)) if k.startswith("read-parquet-"): t = v.creation_info["args"] assert isinstance(t, tuple) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 3397d6e4..e629182d 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -239,12 +239,34 @@ def test_replace_transformation_on_disk_non_raster(self, shapes, points): t1 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) assert type(t1) == Scale + def test_overwrite_files_without_backed_data(self, full_sdata): + with tempfile.TemporaryDirectory() as tmpdir: + f = os.path.join(tmpdir, "data.zarr") + old_data = SpatialData() + old_data.write(f) + # Since not backed, no risk of overwriting backing data. + # Should not raise "The file path specified is the same as the one used for backing." 
+            full_sdata.write(f, overwrite=True)
+
+    def test_not_overwrite_files_without_backed_data_but_with_dask_backed_data(self, full_sdata, points):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            f = os.path.join(tmpdir, "data.zarr")
+            points.write(f)
+            points2 = SpatialData.read(f)
+            p = points2["points_0"]
+            full_sdata["points_0"] = p
+            with pytest.raises(
+                ValueError,
+                match="The file path specified is a parent directory of one or more files used for backing for one or ",
+            ):
+                full_sdata.write(f, overwrite=True)
+
     def test_overwrite_files_with_backed_data(self, full_sdata):
         # addressing https://github.com/scverse/spatialdata/issues/137
         with tempfile.TemporaryDirectory() as tmpdir:
             f = os.path.join(tmpdir, "data.zarr")
             full_sdata.write(f)
-            with pytest.raises(ValueError):
+            with pytest.raises(ValueError, match="The file path specified is the same as the one used for backing."):
                 full_sdata.write(f, overwrite=True)
 
             # support for overwriting backed sdata has been temporarily removed
diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py
index 3e2cb04f..d8f86c44 100644
--- a/tests/io/test_utils.py
+++ b/tests/io/test_utils.py
@@ -5,12 +5,13 @@
 import numpy as np
 import pytest
 from spatialdata import read_zarr, save_transformations
-from spatialdata._io._utils import get_backing_files
+from spatialdata._io._utils import get_dask_backing_files
 from spatialdata._utils import multiscale_spatial_image_from_data_tree
 from spatialdata.transformations import Scale, get_transformation, set_transformation
 
 
 def test_backing_files_points(points):
+    """Test the ability to identify the backing files of a Dask dataframe by examining its computational graph."""
     with tempfile.TemporaryDirectory() as tmp_dir:
         f0 = os.path.join(tmp_dir, "points0.zarr")
         f1 = os.path.join(tmp_dir, "points1.zarr")
@@ -21,7 +22,7 @@
         p0 = points0.points["points_0"]
         p1 = points1.points["points_0"]
         p2 = dd.concat([p0, p1], axis=0)
-        files = get_backing_files(p2)
+        files = get_dask_backing_files(p2)
         expected_zarr_locations = [
             os.path.realpath(os.path.join(f, "points/points_0/points.parquet")) for f in [f0, f1]
         ]
@@ -29,6 +30,10 @@
 
 
 def test_backing_files_images(images):
+    """
+    Test the ability to identify the backing files of single-scale and multiscale images by examining their
+    computational graph.
+    """
     with tempfile.TemporaryDirectory() as tmp_dir:
         f0 = os.path.join(tmp_dir, "images0.zarr")
         f1 = os.path.join(tmp_dir, "images1.zarr")
@@ -41,7 +46,7 @@
         im0 = images0.images["image2d"]
         im1 = images1.images["image2d"]
         im2 = im0 + im1
-        files = get_backing_files(im2)
+        files = get_dask_backing_files(im2)
         expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d")) for f in [f0, f1]]
         assert set(files) == set(expected_zarr_locations)
 
@@ -49,13 +54,17 @@
         im3 = images0.images["image2d_multiscale"]
         im4 = images1.images["image2d_multiscale"]
         im5 = multiscale_spatial_image_from_data_tree(im3 + im4)
-        files = get_backing_files(im5)
+        files = get_dask_backing_files(im5)
         expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d_multiscale")) for f in [f0, f1]]
         assert set(files) == set(expected_zarr_locations)
 
 
 # TODO: this function here below is very similar to the above, unify the test with the above or delete this todo
 def test_backing_files_labels(labels):
+    """
+    Test the ability to identify the backing files of single-scale and multiscale labels by examining their
+    computational graph.
+    """
     with tempfile.TemporaryDirectory() as tmp_dir:
         f0 = os.path.join(tmp_dir, "labels0.zarr")
         f1 = os.path.join(tmp_dir, "labels1.zarr")
@@ -68,7 +77,7 @@
         im0 = labels0.labels["labels2d"]
         im1 = labels1.labels["labels2d"]
         im2 = im0 + im1
-        files = get_backing_files(im2)
+        files = get_dask_backing_files(im2)
         expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d")) for f in [f0, f1]]
         assert set(files) == set(expected_zarr_locations)
 
@@ -76,11 +85,37 @@
         im3 = labels0.labels["labels2d_multiscale"]
         im4 = labels1.labels["labels2d_multiscale"]
         im5 = multiscale_spatial_image_from_data_tree(im3 + im4)
-        files = get_backing_files(im5)
+        files = get_dask_backing_files(im5)
         expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d_multiscale")) for f in [f0, f1]]
         assert set(files) == set(expected_zarr_locations)
 
 
+def test_backing_files_combining_points_and_images(points, images):
+    """
+    Test the ability to identify the backing files of an object that depends on both Dask dataframes and Dask arrays
+    by examining its computational graph.
+    """
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        f0 = os.path.join(tmp_dir, "points0.zarr")
+        f1 = os.path.join(tmp_dir, "images1.zarr")
+        points.write(f0)
+        images.write(f1)
+        points0 = read_zarr(f0)
+        images1 = read_zarr(f1)
+
+        p0 = points0.points["points_0"]
+        im1 = images1.images["image2d"]
+        v = p0["x"].loc[0].values
+        v.compute_chunk_sizes()
+        im2 = v + im1
+        files = get_dask_backing_files(im2)
+        expected_zarr_locations = [
+            os.path.realpath(os.path.join(f0, "points/points_0/points.parquet")),
+            os.path.realpath(os.path.join(f1, "images/image2d")),
+        ]
+        assert set(files) == set(expected_zarr_locations)
+
+
 def test_save_transformations(labels):
     with tempfile.TemporaryDirectory() as tmp_dir:
         f0 = os.path.join(tmp_dir, "labels0.zarr")
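
Usage sketch, not part of the diff above: it shows how the newly exported get_dask_backing_files can be combined with the same containment check that SpatialData.write() now performs, to detect up front whether a target store would clobber files backing the object. The store path "data.zarr" is a hypothetical example.

# Sketch assuming an existing Zarr store at "data.zarr"; mirrors the safeguard added to write().
from pathlib import Path

from spatialdata import get_dask_backing_files, read_zarr

sdata = read_zarr("data.zarr")  # hypothetical existing store
target = Path("data.zarr").resolve()

# Files that appear in the Dask computational graphs of the object's elements
# (Zarr-backed arrays and parquet-backed dataframes).
backing = get_dask_backing_files(sdata)

# Refuse to write into a directory that contains any of the backing files (Path.is_relative_to, Python 3.9+).
if any(Path(fp).resolve().is_relative_to(target) for fp in backing):
    print(f"Writing to {target} would overwrite Dask-backed data; load elements into memory or pick another path.")
else:
    sdata.write(target, overwrite=True)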