diff --git a/examples/data_loading/plot_tilefetcher_loader.py b/examples/data_loading/plot_tilefetcher_loader.py index 29a483afa..81ca12be6 100644 --- a/examples/data_loading/plot_tilefetcher_loader.py +++ b/examples/data_loading/plot_tilefetcher_loader.py @@ -103,5 +103,6 @@ def get_tile( rounds=range(num_r), chs=range(num_c), zplanes=range(num_z), + group_by=(Axes.ROUND, Axes.CH), ) print(repr(stack)) diff --git a/starfish/core/imagestack/imagestack.py b/starfish/core/imagestack/imagestack.py index f23e461dc..8ad58bf16 100644 --- a/starfish/core/imagestack/imagestack.py +++ b/starfish/core/imagestack/imagestack.py @@ -10,6 +10,7 @@ from typing import ( Any, Callable, + Collection, Hashable, Iterable, Iterator, @@ -91,7 +92,7 @@ class ImageStack: the shape of the image tensor by categorical index (channels, imaging rounds, z-layers) """ - def __init__(self, data: xr.DataArray, tile_data: Optional[TileCollectionData]=None): + def __init__(self, data: xr.DataArray, tile_data: TileCollectionData): self._data = data self._data_loaded = False self._tile_data = tile_data @@ -214,9 +215,19 @@ def load_by_selector(selector): return tile_dtype - with ThreadPoolExecutor() as tpe: - # gather all the data types of the tiles to ensure that they are compatible. - tile_dtypes = set(tpe.map(load_by_selector, all_selectors)) + if len(self._tile_data.group_by) == 0: + with ThreadPoolExecutor() as tpe: + # gather all the data types of the tiles to ensure that they are compatible. + tile_dtypes = set(tpe.map(load_by_selector, all_selectors)) + else: + tile_dtypes = set() + group_by_selectors = list(self._iter_axes(self._tile_data.group_by)) + non_group_by_selectors = list(self._iter_axes( + {Axes.ROUND, Axes.CH, Axes.ZPLANE} - self._tile_data.group_by)) + for group_by_selector in group_by_selectors: + for non_group_by_selector in non_group_by_selectors: + tile_dtypes.add(load_by_selector( + {**group_by_selector, **non_group_by_selector})) pbar.close() tile_dtype_kinds = set(tile_dtype.kind for tile_dtype in tile_dtypes) @@ -269,7 +280,7 @@ def from_tilefetcher( rounds: Sequence[int], chs: Sequence[int], zplanes: Sequence[int], - axes_order: Optional[Sequence[Axes]] = None, + group_by: Optional[Collection[Axes]] = None, crop_parameters: Optional[CropParameters]=None, ) -> "ImageStack": """ @@ -289,21 +300,11 @@ def from_tilefetcher( The channels to include in this ImageStack. zplanes : Sequence[int] The zplanes to include in this ImageStack. - axes_order : Optional[Sequence[Axes]] - Ordering for which axes vary, in order of the slowest changing axis to the fastest. For - instance, if the order is (ROUND, Z, CH) and each dimension has size 2, then the - sequence is: - - (ROUND=0, CH=0, Z=0) - (ROUND=0, CH=1, Z=0) - (ROUND=0, CH=0, Z=1) - (ROUND=0, CH=1, Z=1) - (ROUND=1, CH=0, Z=0) - (ROUND=1, CH=1, Z=0) - (ROUND=1, CH=0, Z=1) - (ROUND=1, CH=1, Z=1) - - (default = (Axes.Z, Axes.ROUND, Axes.CH)) + group_by : Optional[Set[Axes]] + Axes to load the data by. If an axis is present in this list, all the data for a given + value along that axis will be loaded concurrently. For example, if group_by is + (Axes.ROUND, Axes.CH), then all the data for ROUND=2, CH=1 will be loaded before we + progress to ROUND=3, CH=1. crop_parameters : Optional[CropParameters] If cropping of the data is desired, it should be specified here. @@ -315,7 +316,7 @@ def from_tilefetcher( from starfish.core.imagestack.parser.tilefetcher import TileFetcherData tile_data: TileCollectionData = TileFetcherData( - tilefetcher, tile_shape, fov, rounds, chs, zplanes, axes_order) + tilefetcher, tile_shape, fov, rounds, chs, zplanes, group_by) if crop_parameters is not None: tile_data = CroppedTileCollectionData(tile_data, crop_parameters) return ImageStack.from_tile_collection_data(tile_data) diff --git a/starfish/core/imagestack/parser/_tiledata.py b/starfish/core/imagestack/parser/_tiledata.py index 03d9fbe17..70a07eafa 100644 --- a/starfish/core/imagestack/parser/_tiledata.py +++ b/starfish/core/imagestack/parser/_tiledata.py @@ -1,4 +1,4 @@ -from typing import Collection, Mapping +from typing import Collection, Mapping, Set import numpy as np @@ -52,6 +52,11 @@ def tile_shape(self) -> Mapping[Axes, int]: """Returns the shape of a tile.""" raise NotImplementedError() + @property + def group_by(self) -> Set[Axes]: + """Returns the axes to group by when we load the data.""" + raise NotImplementedError() + @property def extras(self) -> dict: """Returns the extras metadata for the TileSet.""" diff --git a/starfish/core/imagestack/parser/crop.py b/starfish/core/imagestack/parser/crop.py index cb18d8c91..c38e6007b 100644 --- a/starfish/core/imagestack/parser/crop.py +++ b/starfish/core/imagestack/parser/crop.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Collection, List, Mapping, MutableSequence, Optional, Tuple, Union +from typing import Collection, List, Mapping, MutableSequence, Optional, Set, Tuple, Union import numpy as np from slicedimage import Tile, TileSet @@ -280,6 +280,11 @@ def __getitem__(self, tilekey: TileKey) -> dict: def keys(self) -> Collection[TileKey]: return self.crop_parameters.filter_tilekeys(self.backing_tile_collection_data.keys()) + @property + def group_by(self) -> Set[Axes]: + """Returns the axes to group by when we load the data.""" + return self.backing_tile_collection_data.group_by + @property def tile_shape(self) -> Mapping[Axes, int]: return self.crop_parameters.crop_shape(self.backing_tile_collection_data.tile_shape) diff --git a/starfish/core/imagestack/parser/numpy/__init__.py b/starfish/core/imagestack/parser/numpy/__init__.py index aab914a73..b68d4788d 100644 --- a/starfish/core/imagestack/parser/numpy/__init__.py +++ b/starfish/core/imagestack/parser/numpy/__init__.py @@ -9,6 +9,7 @@ MutableSequence, Optional, Sequence, + Set, ) import numpy as np @@ -100,6 +101,11 @@ def keys(self) -> Collection[TileKey]: return keys + @property + def group_by(self) -> Set[Axes]: + """Returns the axes to group by when we load the data.""" + return set() + @property def tile_shape(self) -> Mapping[Axes, int]: return {Axes.Y: self.data.shape[-2], Axes.X: self.data.shape[-1]} diff --git a/starfish/core/imagestack/parser/tilefetcher/_parser.py b/starfish/core/imagestack/parser/tilefetcher/_parser.py index c062bf217..59958189c 100644 --- a/starfish/core/imagestack/parser/tilefetcher/_parser.py +++ b/starfish/core/imagestack/parser/tilefetcher/_parser.py @@ -1,11 +1,10 @@ """ This module wraps a TileFetcher to provide the data to instantiate an ImageStack. """ -from typing import Collection, Mapping, MutableMapping, Optional, Sequence +from typing import Collection, Mapping, MutableMapping, Optional, Sequence, Set import numpy as np -from starfish.core.experiment.builder.orderediterator import join_axes_labels, ordered_iterator from starfish.core.experiment.builder.providers import FetchedTile, TileFetcher from starfish.core.imagestack.parser import TileCollectionData, TileData, TileKey from starfish.core.imagestack.physical_coordinates import _get_physical_coordinates_of_z_plane @@ -78,7 +77,7 @@ def __init__( rounds: Sequence[int], chs: Sequence[int], zplanes: Sequence[int], - axes_order: Optional[Sequence[Axes]] = None, + group_by: Optional[Collection[Axes]] = None, ) -> None: self._tile_fetcher = tile_fetcher self._tile_shape = tile_shape @@ -86,9 +85,7 @@ def __init__( self._rounds = rounds self._chs = chs self._zplanes = zplanes - if axes_order is None: - axes_order = Axes.ZPLANE, Axes.ROUND, Axes.CH - self._axes_order = axes_order + self._group_by = set(group_by) if group_by is not None else set() def __getitem__(self, tilekey: TileKey) -> dict: """Returns the extras metadata for a given tile, addressed by its TileKey""" @@ -96,13 +93,18 @@ def __getitem__(self, tilekey: TileKey) -> dict: def keys(self) -> Collection[TileKey]: """Returns a Collection of the TileKey's for all the tiles.""" - axes_sizes = join_axes_labels( - self._axes_order, rounds=self._rounds, chs=self._chs, zplanes=self._zplanes) return [ - TileKey(round=selector[Axes.ROUND], ch=selector[Axes.CH], zplane=selector[Axes.ZPLANE]) - for selector in ordered_iterator(axes_sizes) + TileKey(round=round_label, ch=ch_label, zplane=zplane_label) + for round_label in self._rounds + for ch_label in self._chs + for zplane_label in self._zplanes ] + @property + def group_by(self) -> Set[Axes]: + """Returns the axes to group by when we load the data.""" + return self._group_by + @property def tile_shape(self) -> Mapping[Axes, int]: return self._tile_shape diff --git a/starfish/core/imagestack/parser/tileset/_parser.py b/starfish/core/imagestack/parser/tileset/_parser.py index 76e63ac61..e47ec7b11 100644 --- a/starfish/core/imagestack/parser/tileset/_parser.py +++ b/starfish/core/imagestack/parser/tileset/_parser.py @@ -1,7 +1,7 @@ """ This module parses and retains the extras metadata attached to TileSet extras. """ -from typing import Collection, Mapping, MutableMapping, Tuple +from typing import Collection, Mapping, MutableMapping, Set, Tuple import numpy as np from slicedimage import Tile, TileSet @@ -103,6 +103,11 @@ def keys(self) -> Collection[TileKey]: """Returns a Collection of the TileKey's for all the tiles.""" return self.tiles.keys() + @property + def group_by(self) -> Set[Axes]: + """Returns the axes to group by when we load the data.""" + return set() + @property def tile_shape(self) -> Mapping[Axes, int]: return self._tile_shape diff --git a/starfish/core/imagestack/test/test_from_tilefetcher.py b/starfish/core/imagestack/test/test_from_tilefetcher.py index 1944cbe0a..2189815b1 100644 --- a/starfish/core/imagestack/test/test_from_tilefetcher.py +++ b/starfish/core/imagestack/test/test_from_tilefetcher.py @@ -1,6 +1,7 @@ from typing import Mapping, Union import numpy as np +import pytest from starfish.core.experiment.builder.builder import tile_fetcher_factory from starfish.core.experiment.builder.test.factories.unique_tiles import unique_data, UniqueTiles @@ -18,7 +19,18 @@ def coordinates(self) -> Mapping[Union[str, Coordinates], CoordinateValue]: } +@pytest.mark.parametrize( + "group_by", + [ + None, + set((Axes.ROUND,)), + set((Axes.ROUND, Axes.CH)), + set((Axes.ROUND, Axes.CH, Axes.ZPLANE)), + set((Axes.ZPLANE,)), + ] +) def test_from_tilefetcher( + group_by, rounds=(0, 1, 2, 3), chs=(0, 1, 3), zplanes=(0, 1), @@ -42,6 +54,7 @@ def test_from_tilefetcher( rounds=rounds, chs=chs, zplanes=zplanes, + group_by=group_by, ) assert stack.shape == {