diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py
index 0aa4f3fc..b6874193 100644
--- a/src/spatialdata/_core/query/spatial_query.py
+++ b/src/spatialdata/_core/query/spatial_query.py
@@ -563,6 +563,8 @@ def _(
         if 0 in query_result.shape:
             return None
         assert isinstance(query_result, SpatialImage)
+        # rechunk the query result (not the input image) to avoid irregular chunks
+        query_result = query_result.chunk("auto")
     else:
         assert isinstance(image, MultiscaleSpatialImage)
         assert isinstance(query_result, DataTree)
@@ -579,6 +581,9 @@ def _(
             else:
                 d[k] = xdata
         query_result = MultiscaleSpatialImage.from_dict(d)
+        # rechunk every scale of the query result to avoid irregular chunks
+        for scale in query_result:
+            query_result[scale]["image"] = query_result[scale]["image"].chunk("auto")
     query_result = compute_coordinates(query_result)

     # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these
diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py
index fcbb87ae..81d43864 100644
--- a/tests/io/test_readwrite.py
+++ b/tests/io/test_readwrite.py
@@ -14,7 +14,7 @@ from spatial_image import SpatialImage

 from spatialdata import SpatialData, read_zarr
 from spatialdata._io._utils import _are_directories_identical
-from spatialdata.models import TableModel
+from spatialdata.models import Image2DModel, TableModel
 from spatialdata.transformations.operations import (
     get_transformation,
     set_transformation,
@@ -319,3 +319,19 @@ def test_io_table(shapes):
     shapes2.table = adata
     assert shapes2.table is not None
     assert shapes2.table.shape == (5, 10)
+
+
+def test_bug_rechunking_after_queried_raster():
+    # https://github.com/scverse/spatialdata-io/issues/117
+    # regression test: single-scale and multi-scale rasters returned by a bounding box query must be writable
+    single_scale = Image2DModel.parse(RNG.random((100, 10, 10)), chunks=(5, 5, 5))
+    multi_scale = Image2DModel.parse(RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5))
+    images = {"single_scale": single_scale, "multi_scale": multi_scale}
+    sdata = SpatialData(images=images)
+    queried = sdata.query.bounding_box(
+        axes=("x", "y"), min_coordinate=[2, 5], max_coordinate=[12, 12], target_coordinate_system="global"
+    )
+    with tempfile.TemporaryDirectory() as tmpdir:
+        f = os.path.join(tmpdir, "data.zarr")
+        # no explicit assertion: the test passes when the write completes without raising
+        queried.write(f)