Skip to content

Commit

Permalink
Add group_by for tilefetcher-based ImageStack construction (#1796)
Browse files Browse the repository at this point in the history
When we construct an ImageStack from a tilefetcher, we now allow accept a `group_by` parameter which controls how we load the data.  If an axis is in the `group_by`, then we load all the data for any given value for that axis before moving on to another value for that axis.

For instance, if Axes.ROUND is in `group_by`, then we will load all the data for round=0 before we load the data for round=1.

Test plan: Added tests to verify that the loader works with a variety of `group_by`s.  Dan Goodwin's 3D data loads in 100seconds now instead of 10 minutes.
  • Loading branch information
Tony Tung authored Feb 7, 2020
1 parent 4d652cf commit bc60209
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 34 deletions.
1 change: 1 addition & 0 deletions examples/data_loading/plot_tilefetcher_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ def get_tile(
rounds=range(num_r),
chs=range(num_c),
zplanes=range(num_z),
group_by=(Axes.ROUND, Axes.CH),
)
print(repr(stack))
43 changes: 22 additions & 21 deletions starfish/core/imagestack/imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Any,
BinaryIO,
Callable,
Collection,
Hashable,
Iterable,
Iterator,
Expand Down Expand Up @@ -92,7 +93,7 @@ class ImageStack:
the shape of the image tensor by categorical index (channels, imaging rounds, z-layers)
"""

def __init__(self, data: xr.DataArray, tile_data: Optional[TileCollectionData]=None):
def __init__(self, data: xr.DataArray, tile_data: TileCollectionData):
self._data = data
self._data_loaded = False
self._tile_data = tile_data
Expand Down Expand Up @@ -215,9 +216,19 @@ def load_by_selector(selector):

return tile_dtype

with ThreadPoolExecutor() as tpe:
# gather all the data types of the tiles to ensure that they are compatible.
tile_dtypes = set(tpe.map(load_by_selector, all_selectors))
if len(self._tile_data.group_by) == 0:
with ThreadPoolExecutor() as tpe:
# gather all the data types of the tiles to ensure that they are compatible.
tile_dtypes = set(tpe.map(load_by_selector, all_selectors))
else:
tile_dtypes = set()
group_by_selectors = list(self._iter_axes(self._tile_data.group_by))
non_group_by_selectors = list(self._iter_axes(
{Axes.ROUND, Axes.CH, Axes.ZPLANE} - self._tile_data.group_by))
for group_by_selector in group_by_selectors:
for non_group_by_selector in non_group_by_selectors:
tile_dtypes.add(load_by_selector(
{**group_by_selector, **non_group_by_selector}))
pbar.close()

tile_dtype_kinds = set(tile_dtype.kind for tile_dtype in tile_dtypes)
Expand Down Expand Up @@ -270,7 +281,7 @@ def from_tilefetcher(
rounds: Sequence[int],
chs: Sequence[int],
zplanes: Sequence[int],
axes_order: Optional[Sequence[Axes]] = None,
group_by: Optional[Collection[Axes]] = None,
crop_parameters: Optional[CropParameters]=None,
) -> "ImageStack":
"""
Expand All @@ -290,21 +301,11 @@ def from_tilefetcher(
The channels to include in this ImageStack.
zplanes : Sequence[int]
The zplanes to include in this ImageStack.
axes_order : Optional[Sequence[Axes]]
Ordering for which axes vary, in order of the slowest changing axis to the fastest. For
instance, if the order is (ROUND, Z, CH) and each dimension has size 2, then the
sequence is:
(ROUND=0, CH=0, Z=0)
(ROUND=0, CH=1, Z=0)
(ROUND=0, CH=0, Z=1)
(ROUND=0, CH=1, Z=1)
(ROUND=1, CH=0, Z=0)
(ROUND=1, CH=1, Z=0)
(ROUND=1, CH=0, Z=1)
(ROUND=1, CH=1, Z=1)
(default = (Axes.Z, Axes.ROUND, Axes.CH))
group_by : Optional[Set[Axes]]
Axes to load the data by. If an axis is present in this list, all the data for a given
value along that axis will be loaded concurrently. For example, if group_by is
(Axes.ROUND, Axes.CH), then all the data for ROUND=2, CH=1 will be loaded before we
progress to ROUND=3, CH=1.
crop_parameters : Optional[CropParameters]
If cropping of the data is desired, it should be specified here.
Expand All @@ -316,7 +317,7 @@ def from_tilefetcher(
from starfish.core.imagestack.parser.tilefetcher import TileFetcherData

tile_data: TileCollectionData = TileFetcherData(
tilefetcher, tile_shape, fov, rounds, chs, zplanes, axes_order)
tilefetcher, tile_shape, fov, rounds, chs, zplanes, group_by)
if crop_parameters is not None:
tile_data = CroppedTileCollectionData(tile_data, crop_parameters)
return ImageStack.from_tile_collection_data(tile_data)
Expand Down
7 changes: 6 additions & 1 deletion starfish/core/imagestack/parser/_tiledata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Collection, Mapping
from typing import Collection, Mapping, Set

import numpy as np

Expand Down Expand Up @@ -52,6 +52,11 @@ def tile_shape(self) -> Mapping[Axes, int]:
"""Returns the shape of a tile."""
raise NotImplementedError()

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
raise NotImplementedError()

@property
def extras(self) -> dict:
"""Returns the extras metadata for the TileSet."""
Expand Down
7 changes: 6 additions & 1 deletion starfish/core/imagestack/parser/crop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Collection, List, Mapping, MutableSequence, Optional, Tuple, Union
from typing import Collection, List, Mapping, MutableSequence, Optional, Set, Tuple, Union

import numpy as np
from slicedimage import Tile, TileSet
Expand Down Expand Up @@ -280,6 +280,11 @@ def __getitem__(self, tilekey: TileKey) -> dict:
def keys(self) -> Collection[TileKey]:
return self.crop_parameters.filter_tilekeys(self.backing_tile_collection_data.keys())

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
return self.backing_tile_collection_data.group_by

@property
def tile_shape(self) -> Mapping[Axes, int]:
return self.crop_parameters.crop_shape(self.backing_tile_collection_data.tile_shape)
Expand Down
6 changes: 6 additions & 0 deletions starfish/core/imagestack/parser/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
MutableSequence,
Optional,
Sequence,
Set,
)

import numpy as np
Expand Down Expand Up @@ -100,6 +101,11 @@ def keys(self) -> Collection[TileKey]:

return keys

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
return set()

@property
def tile_shape(self) -> Mapping[Axes, int]:
return {Axes.Y: self.data.shape[-2], Axes.X: self.data.shape[-1]}
Expand Down
22 changes: 12 additions & 10 deletions starfish/core/imagestack/parser/tilefetcher/_parser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""
This module wraps a TileFetcher to provide the data to instantiate an ImageStack.
"""
from typing import Collection, Mapping, MutableMapping, Optional, Sequence
from typing import Collection, Mapping, MutableMapping, Optional, Sequence, Set

import numpy as np

from starfish.core.experiment.builder.orderediterator import join_axes_labels, ordered_iterator
from starfish.core.experiment.builder.providers import FetchedTile, TileFetcher
from starfish.core.imagestack.parser import TileCollectionData, TileData, TileKey
from starfish.core.imagestack.physical_coordinates import _get_physical_coordinates_of_z_plane
Expand Down Expand Up @@ -78,31 +77,34 @@ def __init__(
rounds: Sequence[int],
chs: Sequence[int],
zplanes: Sequence[int],
axes_order: Optional[Sequence[Axes]] = None,
group_by: Optional[Collection[Axes]] = None,
) -> None:
self._tile_fetcher = tile_fetcher
self._tile_shape = tile_shape
self._fov = fov
self._rounds = rounds
self._chs = chs
self._zplanes = zplanes
if axes_order is None:
axes_order = Axes.ZPLANE, Axes.ROUND, Axes.CH
self._axes_order = axes_order
self._group_by = set(group_by) if group_by is not None else set()

def __getitem__(self, tilekey: TileKey) -> dict:
"""Returns the extras metadata for a given tile, addressed by its TileKey"""
return {}

def keys(self) -> Collection[TileKey]:
"""Returns a Collection of the TileKey's for all the tiles."""
axes_sizes = join_axes_labels(
self._axes_order, rounds=self._rounds, chs=self._chs, zplanes=self._zplanes)
return [
TileKey(round=selector[Axes.ROUND], ch=selector[Axes.CH], zplane=selector[Axes.ZPLANE])
for selector in ordered_iterator(axes_sizes)
TileKey(round=round_label, ch=ch_label, zplane=zplane_label)
for round_label in self._rounds
for ch_label in self._chs
for zplane_label in self._zplanes
]

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
return self._group_by

@property
def tile_shape(self) -> Mapping[Axes, int]:
return self._tile_shape
Expand Down
7 changes: 6 additions & 1 deletion starfish/core/imagestack/parser/tileset/_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This module parses and retains the extras metadata attached to TileSet extras.
"""
from typing import Collection, Mapping, MutableMapping, Tuple
from typing import Collection, Mapping, MutableMapping, Set, Tuple

import numpy as np
from slicedimage import Tile, TileSet
Expand Down Expand Up @@ -103,6 +103,11 @@ def keys(self) -> Collection[TileKey]:
"""Returns a Collection of the TileKey's for all the tiles."""
return self.tiles.keys()

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
return set()

@property
def tile_shape(self) -> Mapping[Axes, int]:
return self._tile_shape
Expand Down
13 changes: 13 additions & 0 deletions starfish/core/imagestack/test/test_from_tilefetcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Mapping, Union

import numpy as np
import pytest

from starfish.core.experiment.builder.builder import tile_fetcher_factory
from starfish.core.experiment.builder.test.factories.unique_tiles import unique_data, UniqueTiles
Expand All @@ -18,7 +19,18 @@ def coordinates(self) -> Mapping[Union[str, Coordinates], CoordinateValue]:
}


@pytest.mark.parametrize(
"group_by",
[
None,
set((Axes.ROUND,)),
set((Axes.ROUND, Axes.CH)),
set((Axes.ROUND, Axes.CH, Axes.ZPLANE)),
set((Axes.ZPLANE,)),
]
)
def test_from_tilefetcher(
group_by,
rounds=(0, 1, 2, 3),
chs=(0, 1, 3),
zplanes=(0, 1),
Expand All @@ -42,6 +54,7 @@ def test_from_tilefetcher(
rounds=rounds,
chs=chs,
zplanes=zplanes,
group_by=group_by,
)

assert stack.shape == {
Expand Down

0 comments on commit bc60209

Please sign in to comment.