Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up max projection #1379

Merged
merged 1 commit into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion starfish/core/_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def display(

if stack is not None:
if project_axes is not None:
stack = stack.max_proj(*project_axes)
stack = stack.reduce(project_axes, func="max")

viewer.add_image(stack.xarray.values,
rgb=False,
Expand Down
7 changes: 7 additions & 0 deletions starfish/core/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
class DeprecatedAPIError(Exception):
"""
Raised when using an API that has been deprecated.
"""
pass


class DataFormatWarning(Warning):
"""
Warnings given by starfish when the data is not formatted as expected, though not fatally.
Expand Down
33 changes: 23 additions & 10 deletions starfish/core/image/Filter/max_proj.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import warnings
from typing import Iterable, Optional, Union
from typing import Iterable, MutableMapping, Optional, Sequence, Union

import numpy as np

from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.types import Axes
from starfish.core.types import Axes, Coordinates, Number
from ._base import FilterAlgorithm


Expand All @@ -15,7 +17,7 @@ class MaxProject(FilterAlgorithm):

Parameters
----------
dims : Axes
dims : Iterable[Union[Axes, str]]
one or more Axes to project over

See Also
Expand All @@ -25,21 +27,18 @@ class MaxProject(FilterAlgorithm):
"""

def __init__(self, dims: Iterable[Union[Axes, str]]) -> None:

warnings.warn(
"Filter.MaxProject is being deprecated in favor of Filter.Reduce(func='max')",
DeprecationWarning,
)
self.dims = dims
self.dims = set(Axes(dim) for dim in dims)

_DEFAULT_TESTING_PARAMETERS = {"dims": 'r'}

def run(
self,
stack: ImageStack,
in_place: bool = False,
verbose: bool = False,
n_processes: Optional[int] = None,
*args,
) -> Optional[ImageStack]:
"""Perform filtering of an image stack
Expand All @@ -58,8 +57,22 @@ def run(
Returns
-------
ImageStack :
If in-place is False, return the results of filter as a new stack. Otherwise return the
original stack.
The max projection of an image across one or more axis.

"""
return stack.max_proj(*tuple(Axes(dim) for dim in self.dims))
max_projection = stack.xarray.max([dim.value for dim in self.dims])
max_projection = max_projection.expand_dims(tuple(dim.value for dim in self.dims))
max_projection = max_projection.transpose(*stack.xarray.dims)
physical_coords: MutableMapping[Coordinates, Sequence[Number]] = {}
for axis, coord in (
(Axes.X, Coordinates.X),
(Axes.Y, Coordinates.Y),
(Axes.ZPLANE, Coordinates.Z)):
if axis in self.dims:
# this axis was projected out of existence.
assert coord.value not in max_projection.coords
physical_coords[coord] = [np.average(stack.xarray.coords[coord.value])]
else:
physical_coords[coord] = max_projection.coords[coord.value]
max_proj_stack = ImageStack.from_numpy(max_projection.values, coordinates=physical_coords)
return max_proj_stack
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait does it still make sense to have the max_proj filter is stack.reduce('max') does the same thing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_proj filter is deprecated. Reduce filter still exists. stack.reduce is an alias for that filter.

2 changes: 0 additions & 2 deletions starfish/core/image/Filter/test/test_api_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
- constructor accepts 3d and 2d data, where default is 2d
- values emitted by a filter are floats between 0 and 1 (inclusive)
- exposes a `run`() method
- run accepts an in-place parameter which defaults to True
- run always returns an ImageStack (if in-place, returns a reference to the modified input data)
- run accepts an `n_processes` parameter which determines
- run accepts a `verbose` parameter, which triggers tqdm to print progress

To add a new filter, simply add default
Expand Down
4 changes: 2 additions & 2 deletions starfish/core/image/Segment/watershed.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def run(
"""

# create a 'stain' for segmentation
mp = primary_images.max_proj(Axes.CH, Axes.ZPLANE)
mp = primary_images.reduce({Axes.CH, Axes.ZPLANE}, func="max")
mp_numpy = mp._squeezed_numpy(Axes.CH, Axes.ZPLANE)
stain = np.mean(mp_numpy, axis=0)
stain = stain / stain.max()
Expand All @@ -86,7 +86,7 @@ def run(
disk_size_markers = None
disk_size_mask = None

nuclei_mp = nuclei.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
nuclei_mp = nuclei.reduce({Axes.ROUND, Axes.CH, Axes.ZPLANE}, func="max")
nuclei__mp_numpy = nuclei_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE)
self._segmentation_instance = _WatershedSegmenter(nuclei__mp_numpy, stain)
label_image = self._segmentation_instance.segment(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from starfish import data
from starfish.core.image import Filter
from starfish.core.image._registration.ApplyTransform.warp import Warp
from starfish.core.image._registration.LearnTransform.translation import Translation
from starfish.core.types import Axes
Expand Down Expand Up @@ -35,7 +36,7 @@ def test_calculate_translation_transforms_and_apply():
reference_stack = exp.fov().get_image('dots')
translation = Translation(reference_stack=reference_stack, axes=Axes.ROUND)
# Calculate max_proj accrss
mp = stack.max_proj(Axes.CH, Axes.ZPLANE)
mp = Filter.Reduce((Axes.CH, Axes.ZPLANE)).run(stack)
transform_list = translation.run(mp)
apply_transform = Warp()
warped_stack = apply_transform.run(stack=stack, transforms_list=transform_list)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from starfish import data
from starfish.core.image import Filter
from starfish.core.image._registration.LearnTransform.translation import Translation
from starfish.core.types import Axes

Expand All @@ -26,7 +27,7 @@ def test_learn_transforms_translation():
reference_stack = exp.fov().get_image('dots')
translation = Translation(reference_stack=reference_stack, axes=Axes.ROUND)
# Calculate max_proj accrss CH/Z
stack = stack.max_proj(Axes.CH, Axes.ZPLANE)
stack = Filter.Reduce((Axes.CH, Axes.ZPLANE)).run(stack)
transform_list = translation.run(stack)
# assert there's a transofrmation object for each round
assert len(transform_list.transforms) == stack.num_rounds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from starfish import data
from starfish.core.image import Filter
from starfish.core.image._registration.LearnTransform.translation import Translation
from starfish.core.image._registration.transforms_list import TransformsList
from starfish.core.types import Axes, TransformType
Expand All @@ -17,7 +18,7 @@ def test_export_import_transforms_object():
reference_stack = exp.fov().get_image('dots')
translation = Translation(reference_stack=reference_stack, axes=Axes.ROUND)
# Calculate max_proj accrss CH/Z
stack = stack.max_proj(Axes.CH, Axes.ZPLANE)
stack = Filter.Reduce((Axes.CH, Axes.ZPLANE)).run(stack)
transform_list = translation.run(stack)
_, filename = tempfile.mkstemp()
# save to tempfile and import
Expand Down
35 changes: 5 additions & 30 deletions starfish/core/imagestack/imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from tqdm import tqdm

from starfish.core.config import StarfishConfig
from starfish.core.errors import DataFormatWarning
from starfish.core.errors import DataFormatWarning, DeprecatedAPIError
from starfish.core.imagestack import indexing_utils
from starfish.core.imagestack.parser import TileCollectionData, TileKey
from starfish.core.imagestack.parser.crop import CropParameters, CroppedTileCollectionData
Expand Down Expand Up @@ -1143,36 +1143,11 @@ def tile_opener(tileset_path: Path, tile, ext):
tile_format=tile_format)

def max_proj(self, *dims: Axes) -> "ImageStack":
"""return a max projection over one or more axis of the image tensor

Parameters
----------
dims : Axes
one or more axes to project over

Returns
-------
np.ndarray :
max projection

"""
self._ensure_data_loaded()
max_projection = self._data.max([dim.value for dim in dims])
max_projection = max_projection.expand_dims(tuple(dim.value for dim in dims))
max_projection = max_projection.transpose(*self.xarray.dims)
physical_coords: MutableMapping[Coordinates, Sequence[Number]] = {}
for axis, coord in (
(Axes.X, Coordinates.X),
(Axes.Y, Coordinates.Y),
(Axes.ZPLANE, Coordinates.Z)):
if axis in dims:
# this axis was projected out of existence.
assert coord.value not in max_projection.coords
physical_coords[coord] = [np.average(self._data.coords[coord.value])]
else:
physical_coords[coord] = max_projection.coords[coord.value]
max_proj_stack = ImageStack.from_numpy(max_projection.values, coordinates=physical_coords)
return max_proj_stack
This method is deprecated. Please ``ImageStack.reduce(axes, func="max")`` to do max
projection operations.
"""
raise DeprecatedAPIError("Please Filter.MaxProject to do max projection operations.")

def _squeezed_numpy(self, *dims: Axes):
"""return this ImageStack's data as a squeezed numpy array"""
Expand Down
3 changes: 2 additions & 1 deletion starfish/core/spots/DetectSpots/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def detect_spots(data_stack: ImageStack,

if reference_image is not None:
if reference_image_max_projection_axes is not None:
reference_image = reference_image.max_proj(*reference_image_max_projection_axes)
reference_image = reference_image.reduce(
reference_image_max_projection_axes, func="max")
data_image = reference_image._squeezed_numpy(*reference_image_max_projection_axes)
else:
data_image = reference_image.xarray
Expand Down