Allow force-overwriting existing files (non-backing) #344

Merged 12 commits on Jan 6, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][].

## [0.0.x] - tbd

### Minor

- Improved usability and robustness of `sdata.write()` when `overwrite=True` (@aeisenbarth)

### Added

### Fixed
2 changes: 2 additions & 0 deletions docs/api.md
@@ -29,6 +29,7 @@ Operations on `SpatialData` objects.
get_extent
match_table_to_element
concatenate
transform
rasterize
aggregate
```
@@ -133,4 +134,5 @@ The transformations that can be defined between elements and coordinate systems

read_zarr
save_transformations
get_dask_backing_files
```
3 changes: 2 additions & 1 deletion src/spatialdata/__init__.py
@@ -28,6 +28,7 @@
"read_zarr",
"unpad_raster",
"save_transformations",
"get_dask_backing_files",
]

from spatialdata import dataloader, models, transformations
@@ -40,6 +41,6 @@
from spatialdata._core.query.relational_query import get_values, match_table_to_element
from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query
from spatialdata._core.spatialdata import SpatialData
from spatialdata._io._utils import save_transformations
from spatialdata._io._utils import get_dask_backing_files, save_transformations
from spatialdata._io.io_zarr import read_zarr
from spatialdata._utils import unpad_raster
24 changes: 17 additions & 7 deletions src/spatialdata/_core/spatialdata.py
@@ -571,6 +571,7 @@ def write(
consolidate_metadata: bool = True,
) -> None:
from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table
from spatialdata._io._utils import get_dask_backing_files

"""Write the SpatialData object to Zarr."""
if isinstance(file_path, str):
@@ -583,21 +584,30 @@
# old code to support overwriting the backing file
# target_path = None
# tmp_zarr_file = None

if os.path.exists(file_path):
if parse_url(file_path, mode="r") is None:
raise ValueError(
"The target file path specified already exists, and it has been detected to not be "
"a Zarr store. Overwriting non-Zarr stores is not supported to prevent accidental "
"data loss."
)
if not overwrite and str(self.path) != str(file_path):
if not overwrite:
raise ValueError("The Zarr store already exists. Use `overwrite=True` to overwrite the store.")
raise ValueError(
"The file path specified is the same as the one used for backing. "
"Overwriting the backing file is not supported to prevent accidental data loss."
"We are discussing how to support this use case in the future, if you would like us to "
"support it please leave a comment on https://github.com/scverse/spatialdata/pull/138"
)
if self.is_backed() and str(self.path) == str(file_path):
raise ValueError(
"The file path specified is the same as the one used for backing. "
"Overwriting the backing file is not supported to prevent accidental data loss."
"We are discussing how to support this use case in the future, if you would like us to "
"support it please leave a comment on https://github.com/scverse/spatialdata/pull/138"
)
if any(Path(fp).resolve().is_relative_to(file_path.resolve()) for fp in get_dask_backing_files(self)):
raise ValueError(
"The file path specified is a parent directory of one or more files used for backing for one or "
"more elements in the SpatialData object. You can either load every element of the SpatialData "
"object in memory, or save the current spatialdata object to a different path."
)

# old code to support overwriting the backing file
# else:
# target_path = tempfile.TemporaryDirectory()
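Note: the effect of the reworked checks in `write()` is that an existing Zarr store can now be force-overwritten with `overwrite=True`, as long as it is not the store backing the object itself and not a parent directory of any file backing a lazy (Dask-loaded) element. A minimal sketch of the intended behavior, adapted from the new tests further down (paths are illustrative):

```python
import os
import tempfile

from spatialdata import SpatialData

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "data.zarr")

    # First write creates the store; the writing object becomes backed by it.
    SpatialData().write(path)

    # A different, non-backed object may force-overwrite the existing store.
    other = SpatialData()
    other.write(path, overwrite=True)

    # Writing to an existing store without overwrite=True raises ValueError,
    # as does overwriting the store that backs the object itself, or a parent
    # directory of any Dask-backed file of one of its elements.
```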
2 changes: 2 additions & 0 deletions src/spatialdata/_io/__init__.py
@@ -1,3 +1,4 @@
from spatialdata._io._utils import get_dask_backing_files
from spatialdata._io.format import SpatialDataFormatV01
from spatialdata._io.io_points import write_points
from spatialdata._io.io_raster import write_image, write_labels
@@ -11,4 +12,5 @@
"write_shapes",
"write_table",
"SpatialDataFormatV01",
"get_dask_backing_files",
]
58 changes: 40 additions & 18 deletions src/spatialdata/_io/_utils.py
@@ -11,12 +11,12 @@
from typing import Any

import zarr
from dask.array.core import Array as DaskArray
from dask.dataframe.core import DataFrame as DaskDataFrame
from multiscale_spatial_image import MultiscaleSpatialImage
from ome_zarr.format import Format
from ome_zarr.writer import _get_valid_axes
from spatial_image import SpatialImage
from xarray import DataArray

from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import iterate_pyramid_levels
@@ -203,37 +203,59 @@ def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool:
return _are_directories_identical(os.path.join(tmpdir, "a.zarr"), os.path.join(tmpdir, "b.zarr"))


def _get_backing_files_raster(raster: DataArray) -> list[str]:
files = []
for k, v in raster.data.dask.layers.items():
if k.startswith("original-from-zarr-"):
mapping = v.mapping[k]
path = mapping.store.path
files.append(os.path.realpath(path))
return files
@singledispatch
def get_dask_backing_files(element: SpatialData | SpatialImage | MultiscaleSpatialImage | DaskDataFrame) -> list[str]:
"""
Get the backing files that appear in the Dask computational graph of an element/any element of a SpatialData object.

Parameters
----------
element
The element to get the backing files from.

@singledispatch
def get_backing_files(element: SpatialImage | MultiscaleSpatialImage | DaskDataFrame) -> list[str]:
Returns
-------
List of backing files.

Notes
-----
It is possible for lazy objects to be constructed from multiple files.
"""
raise TypeError(f"Unsupported type: {type(element)}")


@get_backing_files.register(SpatialImage)
@get_dask_backing_files.register(SpatialData)
def _(element: SpatialData) -> list[str]:
files: set[str] = set()
for e in element._gen_elements_values():
if isinstance(e, (SpatialImage, MultiscaleSpatialImage, DaskDataFrame)):
files = files.union(get_dask_backing_files(e))
return list(files)


@get_dask_backing_files.register(SpatialImage)
def _(element: SpatialImage) -> list[str]:
return _get_backing_files_raster(element)
return _get_backing_files(element.data)


@get_backing_files.register(MultiscaleSpatialImage)
@get_dask_backing_files.register(MultiscaleSpatialImage)
def _(element: MultiscaleSpatialImage) -> list[str]:
xdata0 = next(iter(iterate_pyramid_levels(element)))
return _get_backing_files_raster(xdata0)
return _get_backing_files(xdata0.data)


@get_backing_files.register(DaskDataFrame)
@get_dask_backing_files.register(DaskDataFrame)
def _(element: DaskDataFrame) -> list[str]:
return _get_backing_files(element)


def _get_backing_files(element: DaskArray | DaskDataFrame) -> list[str]:
files = []
layers = element.dask.layers
for k, v in layers.items():
for k, v in element.dask.layers.items():
if k.startswith("original-from-zarr-"):
mapping = v.mapping[k]
path = mapping.store.path
files.append(os.path.realpath(path))
if k.startswith("read-parquet-"):
t = v.creation_info["args"]
assert isinstance(t, tuple)
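For reference, `get_dask_backing_files` inspects the Dask graph of a lazily loaded element (or of every element of a `SpatialData` object) and returns the on-disk files it reads from; this is what `write()` uses to detect conflicts. A short usage sketch mirroring the new tests (the `points` object and the element name "points_0" are assumed from the test fixtures):

```python
import os
import tempfile

import dask.dataframe as dd
from spatialdata import SpatialData, get_dask_backing_files

with tempfile.TemporaryDirectory() as tmp_dir:
    f0 = os.path.join(tmp_dir, "points0.zarr")
    f1 = os.path.join(tmp_dir, "points1.zarr")
    points.write(f0)  # `points` is a SpatialData object from the test fixtures
    points.write(f1)

    p0 = SpatialData.read(f0).points["points_0"]
    p1 = SpatialData.read(f1).points["points_0"]

    # The concatenated, still-lazy dataframe references both source stores.
    combined = dd.concat([p0, p1], axis=0)
    print(get_dask_backing_files(combined))
    # two real paths ending in points/points_0/points.parquet, one under f0 and one under f1
```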
24 changes: 23 additions & 1 deletion tests/io/test_readwrite.py
@@ -239,12 +239,34 @@ def test_replace_transformation_on_disk_non_raster(self, shapes, points):
t1 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name])
assert type(t1) == Scale

def test_overwrite_files_without_backed_data(self, full_sdata):
with tempfile.TemporaryDirectory() as tmpdir:
f = os.path.join(tmpdir, "data.zarr")
old_data = SpatialData()
old_data.write(f)
# Since not backed, no risk of overwriting backing data.
# Should not raise "The file path specified is the same as the one used for backing."
full_sdata.write(f, overwrite=True)

def test_not_overwrite_files_without_backed_data_but_with_dask_backed_data(self, full_sdata, points):
with tempfile.TemporaryDirectory() as tmpdir:
f = os.path.join(tmpdir, "data.zarr")
points.write(f)
points2 = SpatialData.read(f)
p = points2["points_0"]
full_sdata["points_0"] = p
with pytest.raises(
ValueError,
match="The file path specified is a parent directory of one or more files used for backing for one or ",
):
full_sdata.write(f, overwrite=True)

def test_overwrite_files_with_backed_data(self, full_sdata):
# addressing https://github.com/scverse/spatialdata/issues/137
with tempfile.TemporaryDirectory() as tmpdir:
f = os.path.join(tmpdir, "data.zarr")
full_sdata.write(f)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="The file path specified is the same as the one used for backing."):
full_sdata.write(f, overwrite=True)

# support for overwriting backed sdata has been temporarily removed
47 changes: 41 additions & 6 deletions tests/io/test_utils.py
@@ -5,12 +5,13 @@
import numpy as np
import pytest
from spatialdata import read_zarr, save_transformations
from spatialdata._io._utils import get_backing_files
from spatialdata._io._utils import get_dask_backing_files
from spatialdata._utils import multiscale_spatial_image_from_data_tree
from spatialdata.transformations import Scale, get_transformation, set_transformation


def test_backing_files_points(points):
"""Test the ability to identify the backing files of a dask dataframe from examining its computational graph"""
with tempfile.TemporaryDirectory() as tmp_dir:
f0 = os.path.join(tmp_dir, "points0.zarr")
f1 = os.path.join(tmp_dir, "points1.zarr")
@@ -21,14 +22,18 @@ def test_backing_files_points(points):
p0 = points0.points["points_0"]
p1 = points1.points["points_0"]
p2 = dd.concat([p0, p1], axis=0)
files = get_backing_files(p2)
files = get_dask_backing_files(p2)
expected_zarr_locations = [
os.path.realpath(os.path.join(f, "points/points_0/points.parquet")) for f in [f0, f1]
]
assert set(files) == set(expected_zarr_locations)


def test_backing_files_images(images):
"""
Test the ability to identify the backing files of single scale and multiscale images from examining their
computational graph
"""
with tempfile.TemporaryDirectory() as tmp_dir:
f0 = os.path.join(tmp_dir, "images0.zarr")
f1 = os.path.join(tmp_dir, "images1.zarr")
@@ -41,21 +46,25 @@ def test_backing_files_images(images):
im0 = images0.images["image2d"]
im1 = images1.images["image2d"]
im2 = im0 + im1
files = get_backing_files(im2)
files = get_dask_backing_files(im2)
expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d")) for f in [f0, f1]]
assert set(files) == set(expected_zarr_locations)

# multiscale
im3 = images0.images["image2d_multiscale"]
im4 = images1.images["image2d_multiscale"]
im5 = multiscale_spatial_image_from_data_tree(im3 + im4)
files = get_backing_files(im5)
files = get_dask_backing_files(im5)
expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d_multiscale")) for f in [f0, f1]]
assert set(files) == set(expected_zarr_locations)


# TODO: this function here below is very similar to the above, unify the test with the above or delete this todo
def test_backing_files_labels(labels):
"""
Test the ability to identify the backing files of single scale and multiscale labels from examining their
computational graph
"""
with tempfile.TemporaryDirectory() as tmp_dir:
f0 = os.path.join(tmp_dir, "labels0.zarr")
f1 = os.path.join(tmp_dir, "labels1.zarr")
@@ -68,19 +77,45 @@ def test_backing_files_labels(labels):
im0 = labels0.labels["labels2d"]
im1 = labels1.labels["labels2d"]
im2 = im0 + im1
files = get_backing_files(im2)
files = get_dask_backing_files(im2)
expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d")) for f in [f0, f1]]
assert set(files) == set(expected_zarr_locations)

# multiscale
im3 = labels0.labels["labels2d_multiscale"]
im4 = labels1.labels["labels2d_multiscale"]
im5 = multiscale_spatial_image_from_data_tree(im3 + im4)
files = get_backing_files(im5)
files = get_dask_backing_files(im5)
expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d_multiscale")) for f in [f0, f1]]
assert set(files) == set(expected_zarr_locations)


def test_backing_files_combining_points_and_images(points, images):
"""
Test the ability to identify the backing files of an object that depends both on dask dataframes and dask arrays
from examining its computational graph
"""
with tempfile.TemporaryDirectory() as tmp_dir:
f0 = os.path.join(tmp_dir, "points0.zarr")
f1 = os.path.join(tmp_dir, "images1.zarr")
points.write(f0)
images.write(f1)
points0 = read_zarr(f0)
images1 = read_zarr(f1)

p0 = points0.points["points_0"]
im1 = images1.images["image2d"]
v = p0["x"].loc[0].values
v.compute_chunk_sizes()
im2 = v + im1
files = get_dask_backing_files(im2)
expected_zarr_locations = [
os.path.realpath(os.path.join(f0, "points/points_0/points.parquet")),
os.path.realpath(os.path.join(f1, "images/image2d")),
]
assert set(files) == set(expected_zarr_locations)


def test_save_transformations(labels):
with tempfile.TemporaryDirectory() as tmp_dir:
f0 = os.path.join(tmp_dir, "labels0.zarr")