Skip to content

Commit

Permalink
clean up max projection
Browse files Browse the repository at this point in the history
Remove ImageStack.max_proj. Steer users towards using ImageStack.reduce.

Test plan: travis

Addresses comment in https://github.com/spacetx/starfish/pull/1342/files#r288261378
Fixes #220
  • Loading branch information
Tony Tung committed Oct 10, 2019
1 parent a1fbdfe commit b0e553b
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 49 deletions.
2 changes: 1 addition & 1 deletion starfish/core/_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def display(

if stack is not None:
if project_axes is not None:
stack = stack.max_proj(*project_axes)
stack = stack.reduce(project_axes, func="max")

viewer.add_image(stack.xarray.values,
rgb=False,
Expand Down
7 changes: 7 additions & 0 deletions starfish/core/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
class DeprecatedAPIError(Exception):
"""
Raised when using an API that has been deprecated.
"""
pass


class DataFormatWarning(Warning):
"""
Warnings given by starfish when the data is not formatted as expected, though not fatally.
Expand Down
33 changes: 23 additions & 10 deletions starfish/core/image/Filter/max_proj.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import warnings
from typing import Iterable, Optional, Union
from typing import Iterable, MutableMapping, Optional, Sequence, Union

import numpy as np

from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.types import Axes
from starfish.core.types import Axes, Coordinates, Number
from ._base import FilterAlgorithm


Expand All @@ -15,7 +17,7 @@ class MaxProject(FilterAlgorithm):
Parameters
----------
dims : Axes
dims : Iterable[Union[Axes, str]]
one or more Axes to project over
See Also
Expand All @@ -25,21 +27,18 @@ class MaxProject(FilterAlgorithm):
"""

def __init__(self, dims: Iterable[Union[Axes, str]]) -> None:

warnings.warn(
"Filter.MaxProject is being deprecated in favor of Filter.Reduce(func='max')",
DeprecationWarning,
)
self.dims = dims
self.dims = set(Axes(dim) for dim in dims)

_DEFAULT_TESTING_PARAMETERS = {"dims": 'r'}

def run(
self,
stack: ImageStack,
in_place: bool = False,
verbose: bool = False,
n_processes: Optional[int] = None,
*args,
) -> Optional[ImageStack]:
"""Perform filtering of an image stack
Expand All @@ -58,8 +57,22 @@ def run(
Returns
-------
ImageStack :
If in-place is False, return the results of filter as a new stack. Otherwise return the
original stack.
The max projection of an image across one or more axis.
"""
return stack.max_proj(*tuple(Axes(dim) for dim in self.dims))
max_projection = stack.xarray.max([dim.value for dim in self.dims])
max_projection = max_projection.expand_dims(tuple(dim.value for dim in self.dims))
max_projection = max_projection.transpose(*stack.xarray.dims)
physical_coords: MutableMapping[Coordinates, Sequence[Number]] = {}
for axis, coord in (
(Axes.X, Coordinates.X),
(Axes.Y, Coordinates.Y),
(Axes.ZPLANE, Coordinates.Z)):
if axis in self.dims:
# this axis was projected out of existence.
assert coord.value not in max_projection.coords
physical_coords[coord] = [np.average(stack.xarray.coords[coord.value])]
else:
physical_coords[coord] = max_projection.coords[coord.value]
max_proj_stack = ImageStack.from_numpy(max_projection.values, coordinates=physical_coords)
return max_proj_stack
2 changes: 0 additions & 2 deletions starfish/core/image/Filter/test/test_api_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
- constructor accepts 3d and 2d data, where default is 2d
- values emitted by a filter are floats between 0 and 1 (inclusive)
- exposes a `run`() method
- run accepts an in-place parameter which defaults to True
- run always returns an ImageStack (if in-place, returns a reference to the modified input data)
- run accepts an `n_processes` parameter which determines
- run accepts a `verbose` parameter, which triggers tqdm to print progress
To add a new filter, simply add default
Expand Down
4 changes: 2 additions & 2 deletions starfish/core/image/Segment/watershed.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def run(
"""

# create a 'stain' for segmentation
mp = primary_images.max_proj(Axes.CH, Axes.ZPLANE)
mp = primary_images.reduce({Axes.CH, Axes.ZPLANE}, func="max")
mp_numpy = mp._squeezed_numpy(Axes.CH, Axes.ZPLANE)
stain = np.mean(mp_numpy, axis=0)
stain = stain / stain.max()
Expand All @@ -86,7 +86,7 @@ def run(
disk_size_markers = None
disk_size_mask = None

nuclei_mp = nuclei.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
nuclei_mp = nuclei.reduce({Axes.ROUND, Axes.CH, Axes.ZPLANE}, func="max")
nuclei__mp_numpy = nuclei_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE)
self._segmentation_instance = _WatershedSegmenter(nuclei__mp_numpy, stain)
label_image = self._segmentation_instance.segment(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from starfish import data
from starfish.core.image import Filter
from starfish.core.image._registration.ApplyTransform.warp import Warp
from starfish.core.image._registration.LearnTransform.translation import Translation
from starfish.core.types import Axes
Expand Down Expand Up @@ -35,7 +36,7 @@ def test_calculate_translation_transforms_and_apply():
reference_stack = exp.fov().get_image('dots')
translation = Translation(reference_stack=reference_stack, axes=Axes.ROUND)
# Calculate max_proj accrss
mp = stack.max_proj(Axes.CH, Axes.ZPLANE)
mp = Filter.Reduce((Axes.CH, Axes.ZPLANE)).run(stack)
transform_list = translation.run(mp)
apply_transform = Warp()
warped_stack = apply_transform.run(stack=stack, transforms_list=transform_list)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from starfish import data
from starfish.core.image import Filter
from starfish.core.image._registration.LearnTransform.translation import Translation
from starfish.core.types import Axes

Expand All @@ -26,7 +27,7 @@ def test_learn_transforms_translation():
reference_stack = exp.fov().get_image('dots')
translation = Translation(reference_stack=reference_stack, axes=Axes.ROUND)
# Calculate max_proj accrss CH/Z
stack = stack.max_proj(Axes.CH, Axes.ZPLANE)
stack = Filter.Reduce((Axes.CH, Axes.ZPLANE)).run(stack)
transform_list = translation.run(stack)
# assert there's a transofrmation object for each round
assert len(transform_list.transforms) == stack.num_rounds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from starfish import data
from starfish.core.image import Filter
from starfish.core.image._registration.LearnTransform.translation import Translation
from starfish.core.image._registration.transforms_list import TransformsList
from starfish.core.types import Axes, TransformType
Expand All @@ -17,7 +18,7 @@ def test_export_import_transforms_object():
reference_stack = exp.fov().get_image('dots')
translation = Translation(reference_stack=reference_stack, axes=Axes.ROUND)
# Calculate max_proj accrss CH/Z
stack = stack.max_proj(Axes.CH, Axes.ZPLANE)
stack = Filter.Reduce((Axes.CH, Axes.ZPLANE)).run(stack)
transform_list = translation.run(stack)
_, filename = tempfile.mkstemp()
# save to tempfile and import
Expand Down
35 changes: 5 additions & 30 deletions starfish/core/imagestack/imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from tqdm import tqdm

from starfish.core.config import StarfishConfig
from starfish.core.errors import DataFormatWarning
from starfish.core.errors import DataFormatWarning, DeprecatedAPIError
from starfish.core.imagestack import indexing_utils
from starfish.core.imagestack.parser import TileCollectionData, TileKey
from starfish.core.imagestack.parser.crop import CropParameters, CroppedTileCollectionData
Expand Down Expand Up @@ -1143,36 +1143,11 @@ def tile_opener(tileset_path: Path, tile, ext):
tile_format=tile_format)

def max_proj(self, *dims: Axes) -> "ImageStack":
"""return a max projection over one or more axis of the image tensor
Parameters
----------
dims : Axes
one or more axes to project over
Returns
-------
np.ndarray :
max projection
"""
self._ensure_data_loaded()
max_projection = self._data.max([dim.value for dim in dims])
max_projection = max_projection.expand_dims(tuple(dim.value for dim in dims))
max_projection = max_projection.transpose(*self.xarray.dims)
physical_coords: MutableMapping[Coordinates, Sequence[Number]] = {}
for axis, coord in (
(Axes.X, Coordinates.X),
(Axes.Y, Coordinates.Y),
(Axes.ZPLANE, Coordinates.Z)):
if axis in dims:
# this axis was projected out of existence.
assert coord.value not in max_projection.coords
physical_coords[coord] = [np.average(self._data.coords[coord.value])]
else:
physical_coords[coord] = max_projection.coords[coord.value]
max_proj_stack = ImageStack.from_numpy(max_projection.values, coordinates=physical_coords)
return max_proj_stack
This method is deprecated. Please ``ImageStack.reduce(axes, func="max")`` to do max
projection operations.
"""
raise DeprecatedAPIError("Please Filter.MaxProject to do max projection operations.")

def _squeezed_numpy(self, *dims: Axes):
"""return this ImageStack's data as a squeezed numpy array"""
Expand Down
3 changes: 2 additions & 1 deletion starfish/core/spots/DetectSpots/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def detect_spots(data_stack: ImageStack,

if reference_image is not None:
if reference_image_max_projection_axes is not None:
reference_image = reference_image.max_proj(*reference_image_max_projection_axes)
reference_image = reference_image.reduce(
reference_image_max_projection_axes, func="max")
data_image = reference_image._squeezed_numpy(*reference_image_max_projection_axes)
else:
data_image = reference_image.xarray
Expand Down

0 comments on commit b0e553b

Please sign in to comment.