Skip to content

Commit

Permalink
RFC: clean up max projection
Browse files Browse the repository at this point in the history
1. Remove ImageStack.max_proj.  Steer users towards using Filter.MaxProject.
2. Move the core functionality of max_proj to _max_proj() as it is used commonly in internal code.
3. Refactor ISS notebook to use Filter.MaxProject.

#2 is somewhat debatable.  I feel there is a good case that all internal code that wants to run max_projection should already have been reduced.

Test plan: travis
  • Loading branch information
Tony Tung committed Jun 21, 2019
1 parent 83a6202 commit ea9ae4d
Show file tree
Hide file tree
Showing 9 changed files with 510 additions and 106 deletions.
534 changes: 467 additions & 67 deletions notebooks/ISS.ipynb

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions notebooks/py/ISS.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pprint

from starfish import data, FieldOfView
from starfish.image import Filter
from starfish.types import Features, Axes
from starfish.util.plot import imshow_plane
# EPY: END code
Expand Down Expand Up @@ -80,13 +81,19 @@
imshow_plane(single_plane, title="Round: 0, Channel: 0")
# EPY: END code

# EPY: START code
# We use a handful of max projections. Initialize them.
rcz_max_projection = Filter.MaxProject(dims={Axes.ROUND, Axes.CH, Axes.ZPLANE})
cz_max_projection = Filter.MaxProject(dims={Axes.CH, Axes.ZPLANE})
# EPY: END code

# EPY: START markdown
#'dots' is a general stain for all possible transcripts. This image should correspond to the maximum projection of all color channels within a single imaging round. This auxiliary image is useful for registering images from multiple imaging rounds to this reference image. We'll see an example of this further on in the notebook.
# EPY: END markdown

# EPY: START code
dots = fov.get_image("dots")
dots_single_plane = dots.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
dots_single_plane = rcz_max_projection.run(dots)
imshow_plane(dots_single_plane, title="Anchor channel, all RNA molecules")
# EPY: END code

Expand All @@ -96,7 +103,7 @@

# EPY: START code
nuclei = fov.get_image("nuclei")
nuclei_single_plane = nuclei.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
nuclei_single_plane = rcz_max_projection.run(nuclei)
imshow_plane(nuclei_single_plane, title="Nuclei (DAPI) channel")
# EPY: END code

Expand All @@ -107,8 +114,6 @@
# EPY: END markdown

# EPY: START code
from starfish.image import Filter

# filter raw data
masking_radius = 15
filt = Filter.WhiteTophat(masking_radius, is_volume=False)
Expand Down Expand Up @@ -145,8 +150,9 @@
# EPY: START code
from starfish.image import ApplyTransform, LearnTransform

per_round_max_projected = cz_max_projection.run(imgs)
learn_translation = LearnTransform.Translation(reference_stack=dots, axes=Axes.ROUND, upsampling=1000)
transforms_list = learn_translation.run(imgs.max_proj(Axes.CH, Axes.ZPLANE))
transforms_list = learn_translation.run(per_round_max_projected)
warp = ApplyTransform.Warp()
registered_imgs = warp.run(filtered_imgs, transforms_list=transforms_list, in_place=False, verbose=True)
# EPY: END code
Expand Down Expand Up @@ -211,12 +217,6 @@
stain_thresh = .22 # binary mask for overall cells // binarization of stain
min_dist = 57

registered_mp = registered_imgs.max_proj(Axes.CH, Axes.ZPLANE).xarray.squeeze()
stain = np.mean(registered_mp, axis=0)
stain = stain/stain.max()
nuclei = nuclei.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)


seg = Segment.Watershed(
nuclei_threshold=dapi_thresh,
input_threshold=stain_thresh,
Expand Down Expand Up @@ -255,10 +255,10 @@
GENE2 = 'VIM'

rgb = np.zeros(registered_imgs.tile_shape + (3,))
nuclei_mp = nuclei.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
nuclei_mp = rcz_max_projection.run(nuclei)
nuclei_numpy = nuclei_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE)
rgb[:,:,0] = nuclei_numpy
dots_mp = dots.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
dots_mp = rcz_max_projection.run(dots)
dots_mp_numpy = dots_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE)
rgb[:,:,1] = dots_mp_numpy
do = rgb2gray(rgb)
Expand Down
7 changes: 7 additions & 0 deletions starfish/core/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
class DeprecatedAPIError(Exception):
    """
    Raised when using an API that has been deprecated.

    This is an exception rather than a warning: the deprecated call fails
    instead of continuing, so callers must migrate to the replacement API.
    """


class DataFormatWarning(Warning):
"""
Warnings given by starfish when the data is not formatted as expected, though not fatally.
Expand Down
21 changes: 7 additions & 14 deletions starfish/core/image/_filter/max_proj.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Iterable, Optional, Union
from typing import Iterable

from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.types import Axes
Expand All @@ -13,7 +13,7 @@ class MaxProject(FilterAlgorithmBase):
Parameters
----------
dims : Axes
dims : Iterable[Axes]
one or more Axes to project over
See Also
Expand All @@ -22,8 +22,7 @@ class MaxProject(FilterAlgorithmBase):
"""

def __init__(self, dims: Iterable[Union[Axes, str]]) -> None:

def __init__(self, dims: Iterable[Axes]) -> None:
warnings.warn(
"Filter.MaxProject is being deprecated in favor of Filter.Reduce(func='max')",
DeprecationWarning,
Expand All @@ -35,9 +34,7 @@ def __init__(self, dims: Iterable[Union[Axes, str]]) -> None:
def run(
self,
stack: ImageStack,
in_place: bool = False,
verbose: bool = False,
n_processes: Optional[int] = None,
*args,
) -> ImageStack:
"""Perform filtering of an image stack
Expand All @@ -46,21 +43,16 @@ def run(
----------
stack : ImageStack
Stack to be filtered.
in_place : bool
if True, process ImageStack in-place, otherwise return a new stack
verbose : bool
if True, report on filtering progress (default = False)
n_processes : Optional[int]
Number of parallel processes to devote to calculating the filter
Returns
-------
ImageStack :
If in-place is False, return the results of filter as a new stack. Otherwise return the
original stack.
The max projection of an image across one or more axes.
"""
return stack.max_proj(*tuple(Axes(dim) for dim in self.dims))
return stack._max_proj(*self.dims)

@staticmethod
@click.command("MaxProject")
Expand All @@ -75,4 +67,5 @@ def run(
"--dims r --dims c")
@click.pass_context
def _cli(ctx, dims):
ctx.obj["component"]._cli_run(ctx, MaxProject(dims))
formatted_dims = [Axes(dim) for dim in dims]
ctx.obj["component"]._cli_run(ctx, MaxProject(formatted_dims))
2 changes: 0 additions & 2 deletions starfish/core/image/_filter/test/test_api_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
- constructor accepts 3d and 2d data, where default is 2d
- values emitted by a filter are floats between 0 and 1 (inclusive)
- exposes a `run`() method
- run accepts an in-place parameter which defaults to True
- run always returns an ImageStack (if in-place, returns a reference to the modified input data)
- run accepts an `n_processes` parameter which determines
- run accepts a `verbose` parameter, which triggers tqdm to print progress
To add a new filter, simply add default
Expand Down
4 changes: 2 additions & 2 deletions starfish/core/image/_segment/watershed.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run(
"""

# create a 'stain' for segmentation
mp = primary_images.max_proj(Axes.CH, Axes.ZPLANE)
mp = primary_images._max_proj(Axes.CH, Axes.ZPLANE)
mp_numpy = mp._squeezed_numpy(Axes.CH, Axes.ZPLANE)
stain = np.mean(mp_numpy, axis=0)
stain = stain / stain.max()
Expand All @@ -87,7 +87,7 @@ def run(
disk_size_markers = None
disk_size_mask = None

nuclei_mp = nuclei.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
nuclei_mp = nuclei._max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
nuclei__mp_numpy = nuclei_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE)
self._segmentation_instance = _WatershedSegmenter(nuclei__mp_numpy, stain)
label_image = self._segmentation_instance.segment(
Expand Down
10 changes: 6 additions & 4 deletions starfish/core/imagestack/imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tqdm import tqdm

from starfish.core.config import StarfishConfig
from starfish.core.errors import DataFormatWarning
from starfish.core.errors import DataFormatWarning, DeprecatedAPIError
from starfish.core.imagestack import indexing_utils
from starfish.core.imagestack.parser import TileCollectionData, TileKey
from starfish.core.imagestack.parser.crop import CropParameters, CroppedTileCollectionData
Expand Down Expand Up @@ -1136,6 +1136,9 @@ def tile_opener(tileset_path: Path, tile, ext):
tile_format=tile_format)

def max_proj(self, *dims: Axes) -> "ImageStack":
    """Deprecated entry point for max projection; always raises.

    Parameters
    ----------
    dims : Axes
        Accepted so that existing call sites reach the error below; the
        values are not used.

    Raises
    ------
    DeprecatedAPIError
        Always. The message points callers at Filter.MaxProject.
    """
    raise DeprecatedAPIError(
        "Please use Filter.MaxProject to do max projection operations.")

def _max_proj(self, *dims: Axes) -> "ImageStack":
"""return a max projection over one or more axis of the image tensor
Parameters
Expand All @@ -1147,9 +1150,8 @@ def max_proj(self, *dims: Axes) -> "ImageStack":
-------
ImageStack :
max projection
"""
max_projection = self._data.max([dim.value for dim in dims])
max_projection = self.xarray.max([dim.value for dim in dims])
max_projection = max_projection.expand_dims(tuple(dim.value for dim in dims))
max_projection = max_projection.transpose(*self.xarray.dims)
physical_coords: MutableMapping[Coordinates, Sequence[Number]] = {}
Expand All @@ -1160,7 +1162,7 @@ def max_proj(self, *dims: Axes) -> "ImageStack":
if axis in dims:
# this axis was projected out of existence.
assert coord.value not in max_projection.coords
physical_coords[coord] = [np.average(self._data.coords[coord.value])]
physical_coords[coord] = [np.average(self.xarray.coords[coord.value])]
else:
physical_coords[coord] = max_projection.coords[coord.value]
max_proj_stack = ImageStack.from_numpy(max_projection.values, coordinates=physical_coords)
Expand Down
10 changes: 7 additions & 3 deletions starfish/core/imagestack/test/test_max_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from starfish import data
from starfish.core.types import Axes, PhysicalCoordinateTypes
from starfish.image import Filter
from .factories import imagestack_with_coords_factory
from .imagestack_test_utils import verify_physical_coordinates
from ..imagestack import ImageStack
Expand All @@ -13,8 +14,9 @@ def test_max_projection_preserves_dtype():
original_dtype = np.float32
array = np.ones((2, 2, 2), dtype=original_dtype)
image = ImageStack.from_numpy(array.reshape((1, 1, 2, 2, 2)))
max_projector = Filter.MaxProject(dims={Axes.CH, Axes.ROUND, Axes.ZPLANE})

max_projection = image.max_proj(Axes.CH, Axes.ROUND, Axes.ZPLANE)
max_projection = max_projector.run(image)
assert max_projection.xarray.dtype == original_dtype


Expand All @@ -24,9 +26,11 @@ def test_max_projection_preserves_dtype():


def test_max_projection_preserves_coordinates():
max_projector = Filter.MaxProject(dims={Axes.CH, Axes.ROUND, Axes.ZPLANE})

e = data.ISS(use_test_data=True)
nuclei = e.fov().get_image('nuclei')
nuclei_proj = nuclei.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
nuclei_proj = max_projector.run(nuclei)
# Since this data already has only 1 round, 1 ch, 1 zplane
# let's just assert that the max_proj operation didn't change anything
assert nuclei.xarray.equals(nuclei_proj.xarray)
Expand All @@ -44,6 +48,6 @@ def test_max_projection_preserves_coordinates():

stack = imagestack_with_coords_factory(stack_shape, physical_coords)

stack_proj = stack.max_proj(Axes.ROUND, Axes.CH, Axes.ZPLANE)
stack_proj = max_projector.run(stack)
expected_z = np.average(Z_COORDS)
verify_physical_coordinates(stack_proj, X_COORDS, Y_COORDS, expected_z)
2 changes: 1 addition & 1 deletion starfish/core/spots/_detect_spots/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def detect_spots(data_stack: ImageStack,

if reference_image is not None:
if reference_image_max_projection_axes is not None:
reference_image = reference_image.max_proj(*reference_image_max_projection_axes)
reference_image = reference_image._max_proj(*reference_image_max_projection_axes)
data_image = reference_image._squeezed_numpy(*reference_image_max_projection_axes)
else:
data_image = reference_image.xarray
Expand Down

0 comments on commit ea9ae4d

Please sign in to comment.