Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add group_by for tilefetcher-based ImageStack construction #1796

Merged
merged 1 commit into from
Feb 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/data_loading/plot_tilefetcher_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ def get_tile(
rounds=range(num_r),
chs=range(num_c),
zplanes=range(num_z),
group_by=(Axes.ROUND, Axes.CH),
)
print(repr(stack))
43 changes: 22 additions & 21 deletions starfish/core/imagestack/imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import (
Any,
Callable,
Collection,
Hashable,
Iterable,
Iterator,
Expand Down Expand Up @@ -91,7 +92,7 @@ class ImageStack:
the shape of the image tensor by categorical index (channels, imaging rounds, z-layers)
"""

def __init__(self, data: xr.DataArray, tile_data: Optional[TileCollectionData]=None):
def __init__(self, data: xr.DataArray, tile_data: TileCollectionData):
self._data = data
self._data_loaded = False
self._tile_data = tile_data
Expand Down Expand Up @@ -214,9 +215,19 @@ def load_by_selector(selector):

return tile_dtype

with ThreadPoolExecutor() as tpe:
# gather all the data types of the tiles to ensure that they are compatible.
tile_dtypes = set(tpe.map(load_by_selector, all_selectors))
if len(self._tile_data.group_by) == 0:
with ThreadPoolExecutor() as tpe:
# gather all the data types of the tiles to ensure that they are compatible.
tile_dtypes = set(tpe.map(load_by_selector, all_selectors))
else:
tile_dtypes = set()
group_by_selectors = list(self._iter_axes(self._tile_data.group_by))
non_group_by_selectors = list(self._iter_axes(
{Axes.ROUND, Axes.CH, Axes.ZPLANE} - self._tile_data.group_by))
for group_by_selector in group_by_selectors:
for non_group_by_selector in non_group_by_selectors:
tile_dtypes.add(load_by_selector(
{**group_by_selector, **non_group_by_selector}))
pbar.close()

tile_dtype_kinds = set(tile_dtype.kind for tile_dtype in tile_dtypes)
Expand Down Expand Up @@ -269,7 +280,7 @@ def from_tilefetcher(
rounds: Sequence[int],
chs: Sequence[int],
zplanes: Sequence[int],
axes_order: Optional[Sequence[Axes]] = None,
group_by: Optional[Collection[Axes]] = None,
crop_parameters: Optional[CropParameters]=None,
) -> "ImageStack":
"""
Expand All @@ -289,21 +300,11 @@ def from_tilefetcher(
The channels to include in this ImageStack.
zplanes : Sequence[int]
The zplanes to include in this ImageStack.
axes_order : Optional[Sequence[Axes]]
Ordering for which axes vary, in order of the slowest changing axis to the fastest. For
instance, if the order is (ROUND, Z, CH) and each dimension has size 2, then the
sequence is:

(ROUND=0, CH=0, Z=0)
(ROUND=0, CH=1, Z=0)
(ROUND=0, CH=0, Z=1)
(ROUND=0, CH=1, Z=1)
(ROUND=1, CH=0, Z=0)
(ROUND=1, CH=1, Z=0)
(ROUND=1, CH=0, Z=1)
(ROUND=1, CH=1, Z=1)

(default = (Axes.Z, Axes.ROUND, Axes.CH))
group_by : Optional[Collection[Axes]]
Axes to load the data by. If an axis is present in this collection, all the data for a
given value along that axis will be loaded together. For example, if group_by is
(Axes.ROUND, Axes.CH), then all the data for ROUND=2, CH=1 will be loaded before we
progress to ROUND=3, CH=1.
crop_parameters : Optional[CropParameters]
If cropping of the data is desired, it should be specified here.

Expand All @@ -315,7 +316,7 @@ def from_tilefetcher(
from starfish.core.imagestack.parser.tilefetcher import TileFetcherData

tile_data: TileCollectionData = TileFetcherData(
tilefetcher, tile_shape, fov, rounds, chs, zplanes, axes_order)
tilefetcher, tile_shape, fov, rounds, chs, zplanes, group_by)
if crop_parameters is not None:
tile_data = CroppedTileCollectionData(tile_data, crop_parameters)
return ImageStack.from_tile_collection_data(tile_data)
Expand Down
7 changes: 6 additions & 1 deletion starfish/core/imagestack/parser/_tiledata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Collection, Mapping
from typing import Collection, Mapping, Set

import numpy as np

Expand Down Expand Up @@ -52,6 +52,11 @@ def tile_shape(self) -> Mapping[Axes, int]:
"""Returns the shape of a tile."""
raise NotImplementedError()

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
raise NotImplementedError()

@property
def extras(self) -> dict:
"""Returns the extras metadata for the TileSet."""
Expand Down
7 changes: 6 additions & 1 deletion starfish/core/imagestack/parser/crop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Collection, List, Mapping, MutableSequence, Optional, Tuple, Union
from typing import Collection, List, Mapping, MutableSequence, Optional, Set, Tuple, Union

import numpy as np
from slicedimage import Tile, TileSet
Expand Down Expand Up @@ -280,6 +280,11 @@ def __getitem__(self, tilekey: TileKey) -> dict:
def keys(self) -> Collection[TileKey]:
return self.crop_parameters.filter_tilekeys(self.backing_tile_collection_data.keys())

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
return self.backing_tile_collection_data.group_by

@property
def tile_shape(self) -> Mapping[Axes, int]:
return self.crop_parameters.crop_shape(self.backing_tile_collection_data.tile_shape)
Expand Down
6 changes: 6 additions & 0 deletions starfish/core/imagestack/parser/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
MutableSequence,
Optional,
Sequence,
Set,
)

import numpy as np
Expand Down Expand Up @@ -100,6 +101,11 @@ def keys(self) -> Collection[TileKey]:

return keys

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
return set()

@property
def tile_shape(self) -> Mapping[Axes, int]:
return {Axes.Y: self.data.shape[-2], Axes.X: self.data.shape[-1]}
Expand Down
22 changes: 12 additions & 10 deletions starfish/core/imagestack/parser/tilefetcher/_parser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""
This module wraps a TileFetcher to provide the data to instantiate an ImageStack.
"""
from typing import Collection, Mapping, MutableMapping, Optional, Sequence
from typing import Collection, Mapping, MutableMapping, Optional, Sequence, Set

import numpy as np

from starfish.core.experiment.builder.orderediterator import join_axes_labels, ordered_iterator
from starfish.core.experiment.builder.providers import FetchedTile, TileFetcher
from starfish.core.imagestack.parser import TileCollectionData, TileData, TileKey
from starfish.core.imagestack.physical_coordinates import _get_physical_coordinates_of_z_plane
Expand Down Expand Up @@ -78,31 +77,34 @@ def __init__(
rounds: Sequence[int],
chs: Sequence[int],
zplanes: Sequence[int],
axes_order: Optional[Sequence[Axes]] = None,
group_by: Optional[Collection[Axes]] = None,
) -> None:
self._tile_fetcher = tile_fetcher
self._tile_shape = tile_shape
self._fov = fov
self._rounds = rounds
self._chs = chs
self._zplanes = zplanes
if axes_order is None:
axes_order = Axes.ZPLANE, Axes.ROUND, Axes.CH
self._axes_order = axes_order
self._group_by = set(group_by) if group_by is not None else set()

def __getitem__(self, tilekey: TileKey) -> dict:
"""Returns the extras metadata for a given tile, addressed by its TileKey"""
return {}

def keys(self) -> Collection[TileKey]:
"""Returns a Collection of the TileKeys for all the tiles."""
axes_sizes = join_axes_labels(
self._axes_order, rounds=self._rounds, chs=self._chs, zplanes=self._zplanes)
return [
TileKey(round=selector[Axes.ROUND], ch=selector[Axes.CH], zplane=selector[Axes.ZPLANE])
for selector in ordered_iterator(axes_sizes)
TileKey(round=round_label, ch=ch_label, zplane=zplane_label)
for round_label in self._rounds
for ch_label in self._chs
for zplane_label in self._zplanes
]

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
return self._group_by

@property
def tile_shape(self) -> Mapping[Axes, int]:
return self._tile_shape
Expand Down
7 changes: 6 additions & 1 deletion starfish/core/imagestack/parser/tileset/_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This module parses and retains the extras metadata attached to TileSet extras.
"""
from typing import Collection, Mapping, MutableMapping, Tuple
from typing import Collection, Mapping, MutableMapping, Set, Tuple

import numpy as np
from slicedimage import Tile, TileSet
Expand Down Expand Up @@ -103,6 +103,11 @@ def keys(self) -> Collection[TileKey]:
"""Returns a Collection of the TileKey's for all the tiles."""
return self.tiles.keys()

@property
def group_by(self) -> Set[Axes]:
"""Returns the axes to group by when we load the data."""
return set()

@property
def tile_shape(self) -> Mapping[Axes, int]:
return self._tile_shape
Expand Down
13 changes: 13 additions & 0 deletions starfish/core/imagestack/test/test_from_tilefetcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Mapping, Union

import numpy as np
import pytest

from starfish.core.experiment.builder.builder import tile_fetcher_factory
from starfish.core.experiment.builder.test.factories.unique_tiles import unique_data, UniqueTiles
Expand All @@ -18,7 +19,18 @@ def coordinates(self) -> Mapping[Union[str, Coordinates], CoordinateValue]:
}


@pytest.mark.parametrize(
"group_by",
[
None,
set((Axes.ROUND,)),
set((Axes.ROUND, Axes.CH)),
set((Axes.ROUND, Axes.CH, Axes.ZPLANE)),
set((Axes.ZPLANE,)),
]
)
def test_from_tilefetcher(
group_by,
rounds=(0, 1, 2, 3),
chs=(0, 1, 3),
zplanes=(0, 1),
Expand All @@ -42,6 +54,7 @@ def test_from_tilefetcher(
rounds=rounds,
chs=chs,
zplanes=zplanes,
group_by=group_by,
)

assert stack.shape == {
Expand Down