Skip to content

Commit

Permalink
Make map/reduce APIs more intuitive (#1686)
Browse files Browse the repository at this point in the history
Right now, specifying a FunctionSource along with function parameters makes for a confusing function call.

For instance, `Map('divide', 2, module=FunctionSource.np)` means the FunctionSource comes last, which is not intuitive.

This adds the ability for `FunctionSource`s to be called and return a Bundle that includes both the package name and the function name.  The call above would then become: `Map(FunctionSource.np('divide'), 2)`.

Backwards compatibility with the prior API is maintained, but a warning is generated.

If the top-level package is provided twice, it is treated as an error.

Test plan: added test cases to cover the new approach, the old approach that should generate the warning, and the ugly combination that should fail.
  • Loading branch information
Tony Tung authored Dec 13, 2019
1 parent 3584c57 commit a15237c
Show file tree
Hide file tree
Showing 22 changed files with 198 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def iss_pipeline(fov, codebook):
)

# detect spots using laplacian of gaussians approach
dots_max = fov.get_image('dots').reduce((Axes.ROUND, Axes.ZPLANE), func="max", module=FunctionSource.np)
dots_max = fov.get_image('dots').reduce((Axes.ROUND, Axes.ZPLANE), func="max")
# locate spots in a reference image
spots = bd.run(reference_image=dots_max, image_stack=filtered)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/_static/tutorials/exec_image_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# Here, we demonstrate selecting the last 50 pixels of (x, y) for a rounds 2 and 3 using the
# :py:meth:`ImageStack.sel` method.

from starfish.types import Axes, FunctionSource
from starfish.types import Axes

cropped_image: starfish.ImageStack = image.sel(
{Axes.ROUND: (2, 3), Axes.X: (30, 80), Axes.Y: (50, 100)}
Expand Down
2 changes: 1 addition & 1 deletion notebooks/BaristaSeq.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@
"source": [
"from starfish.image import Filter\n",
"from starfish.types import FunctionSource\n",
"max_projector = Filter.Reduce((Axes.ZPLANE,), func=\"max\", module=FunctionSource.np)\n",
"max_projector = Filter.Reduce((Axes.ZPLANE,), func=FunctionSource.np(\"max\"))\n",
"z_projected_image = max_projector.run(img)\n",
"z_projected_nissl = max_projector.run(nissl)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/ISS.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@
" measurement_type='mean',\n",
")\n",
"\n",
"dots_max = dots.reduce((Axes.ROUND, Axes.ZPLANE), func=\"max\", module=FunctionSource.np)\n",
"dots_max = dots.reduce((Axes.ROUND, Axes.ZPLANE), func=FunctionSource.np(\"max\"))\n",
"spots = bd.run(image_stack=registered_imgs, reference_image=dots_max)\n",
"\n",
"decoder = DecodeSpots.PerRoundMaxChannel(codebook=experiment.codebook)\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/MERFISH.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"\n",
"from starfish import display\n",
"from starfish import data, FieldOfView\n",
"from starfish.types import Axes, Features, FunctionSource\n",
"from starfish.types import Axes, Features\n",
"\n",
"from starfish.util.plot import (\n",
" imshow_plane, intensity_histogram, overlay_spot_calls\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/STARmap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion notebooks/py/BaristaSeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
# EPY: START code
from starfish.image import Filter
from starfish.types import FunctionSource
max_projector = Filter.Reduce((Axes.ZPLANE,), func="max", module=FunctionSource.np)
max_projector = Filter.Reduce((Axes.ZPLANE,), func=FunctionSource.np("max"))
z_projected_image = max_projector.run(img)
z_projected_nissl = max_projector.run(nissl)

Expand Down
2 changes: 1 addition & 1 deletion notebooks/py/ISS.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
measurement_type='mean',
)

dots_max = dots.reduce((Axes.ROUND, Axes.ZPLANE), func="max", module=FunctionSource.np)
dots_max = dots.reduce((Axes.ROUND, Axes.ZPLANE), func=FunctionSource.np("max"))
spots = bd.run(image_stack=registered_imgs, reference_image=dots_max)

decoder = DecodeSpots.PerRoundMaxChannel(codebook=experiment.codebook)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/py/MERFISH.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from starfish import display
from starfish import data, FieldOfView
from starfish.types import Axes, Features, FunctionSource
from starfish.types import Axes, Features

from starfish.util.plot import (
imshow_plane, intensity_histogram, overlay_spot_calls
Expand Down
44 changes: 33 additions & 11 deletions starfish/core/image/Filter/map.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import warnings
from typing import (
Optional,
Set,
Union
)

from starfish.core.imagestack.imagestack import _reconcile_clip_and_level, ImageStack
from starfish.core.types import Axes, Clip, FunctionSource, Levels
from starfish.core.types import Axes, Clip, FunctionSource, FunctionSourceBundle, Levels
from ._base import FilterAlgorithm


Expand All @@ -16,18 +17,24 @@ class Map(FilterAlgorithm):
Parameters
----------
func : str
Name of a function in the module specified by the ``module`` parameter to apply across the
dimension(s) specified by dims. The function is resolved by ``getattr(<module>, func)``,
except in the cases of predefined aliases. See :py:class:`FunctionSource` for more
information about aliases.
module : FunctionSource
func : Union[str, FunctionSourceBundle]
Function to apply across the dimension(s) specified by ``dims``.
If this value is a string, then the ``module`` parameter is consulted to determine which
python package is used to find the function. If ``module`` is not specified, then the
default is :py:attr:`FunctionSource.np`.
If this value is a ``FunctionSourceBundle``, then the python package and module name is
obtained from the bundle.
module : Optional[FunctionSource]
Python module that serves as the source of the function. It must be listed as one of the
members of :py:class:`FunctionSource`.
Currently, the supported FunctionSources are:
- ``np``: the top-level package of numpy
- ``scipy``: the top-level package of scipy
This is being deprecated in favor of specifying the function as a ``FunctionSourceBundle``.
in_place : bool
Execute the operation in-place. (default: False)
group_by : Set[Axes]
Expand Down Expand Up @@ -80,16 +87,31 @@ class Map(FilterAlgorithm):

def __init__(
self,
func: str,
func: Union[str, FunctionSourceBundle],
*func_args,
module: FunctionSource = FunctionSource.np,
module: Optional[FunctionSource] = None,
in_place: bool = False,
group_by: Optional[Set[Union[Axes, str]]] = None,
clip_method: Optional[Clip] = None,
level_method: Optional[Levels] = None,
**func_kwargs,
) -> None:
self.func = module._resolve_method(func)
if isinstance(func, str):
if module is not None:
warnings.warn(
f"The module parameter is being deprecated. Use "
f"`func=FunctionSource.{module.name}{func} instead.",
DeprecationWarning)
else:
module = FunctionSource.np
self.func = module(func)
elif isinstance(func, FunctionSourceBundle):
if module is not None:
raise ValueError(
f"When passing in the function as a `FunctionSourceBundle`, module should not "
f"be set."
)
self.func = func
self.in_place = in_place
if group_by is None:
group_by = {Axes.ROUND, Axes.CH, Axes.ZPLANE}
Expand Down Expand Up @@ -122,7 +144,7 @@ def run(

# Apply the reducing function
return stack.apply(
self.func,
self.func.resolve(),
*self.func_args,
group_by=self.group_by,
in_place=self.in_place,
Expand Down
56 changes: 43 additions & 13 deletions starfish/core/image/Filter/reduce.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import (
Iterable,
MutableMapping,
Expand All @@ -8,7 +9,16 @@
import numpy as np

from starfish.core.imagestack.imagestack import _reconcile_clip_and_level, ImageStack
from starfish.core.types import ArrayLike, Axes, Clip, Coordinates, FunctionSource, Levels, Number
from starfish.core.types import (
ArrayLike,
Axes,
Clip,
Coordinates,
FunctionSource,
FunctionSourceBundle,
Levels,
Number,
)
from starfish.core.util.levels import levels
from ._base import FilterAlgorithm

Expand All @@ -21,25 +31,31 @@ class Reduce(FilterAlgorithm):
----------
dims : Iterable[Union[Axes, str]]
one or more Axes to reduce over
func : str
Name of a function in the module specified by the ``module`` parameter to apply across the
dimension(s) specified by dims. The function is resolved by ``getattr(<module>, func)``,
except in the cases of predefined aliases. See :py:class:`FunctionSource` for more
information about aliases.
func : Union[str, FunctionSourceBundle]
Function to apply across the dimension(s) specified by ``dims``.
If this value is a string, then the ``module`` parameter is consulted to determine which
python package is used to find the function. If ``module`` is not specified, then the
default is :py:attr:`FunctionSource.np`.
If this value is a ``FunctionSourceBundle``, then the python package and module name is
obtained from the bundle.
Some common examples for the np FunctionSource:
- amax: maximum intensity projection (applies np.amax)
- max: maximum intensity projection (this is an alias for amax and applies np.amax)
- mean: take the mean across the dim(s) (applies np.mean)
- sum: sum across the dim(s) (applies np.sum)
module : FunctionSource
module : Optional[FunctionSource]
Python module that serves as the source of the function. It must be listed as one of the
members of :py:class:`FunctionSource`.
Currently, the supported FunctionSources are:
- ``np``: the top-level package of numpy
- ``scipy``: the top-level package of scipy
This is being deprecated in favor of specifying the function as a ``FunctionSourceBundle``.
clip_method : Optional[Union[str, :py:class:`~starfish.types.Clip`]]
Deprecated method to control the way that data are scaled to retain skimage dtype
requirements that float data fall in [0, 1]. In all modes, data below 0 are set to 0.
Expand Down Expand Up @@ -85,8 +101,7 @@ class Reduce(FilterAlgorithm):
>>> stack = synthetic_stack()
>>> reducer = Filter.Reduce(
{Axes.ROUND},
func="linalg.norm",
module=FunctionSource.scipy,
func=FunctionSource.scipy("linalg.norm"),
ord=2,
)
>>> norm = reducer.run(stack)
Expand All @@ -100,14 +115,29 @@ class Reduce(FilterAlgorithm):
def __init__(
self,
dims: Iterable[Union[Axes, str]],
func: str = "max",
module: FunctionSource = FunctionSource.np,
func: Union[str, FunctionSourceBundle] = "max",
module: Optional[FunctionSource] = None,
clip_method: Optional[Clip] = None,
level_method: Optional[Levels] = None,
**kwargs
) -> None:
self.dims: Iterable[Axes] = set(Axes(dim) for dim in dims)
self.func = module._resolve_method(func)
if isinstance(func, str):
if module is not None:
warnings.warn(
f"The module parameter is being deprecated. Use "
f"`func=FunctionSource.{module.name}{func} instead.",
DeprecationWarning)
else:
module = FunctionSource.np
self.func = module(func)
elif isinstance(func, FunctionSourceBundle):
if module is not None:
raise ValueError(
f"When passing in the function as a `FunctionSourceBundle`, module should not "
f"be set."
)
self.func = func
self.level_method = _reconcile_clip_and_level(clip_method, level_method)
self.kwargs = kwargs

Expand All @@ -134,7 +164,7 @@ def run(

# Apply the reducing function
reduced = stack.xarray.reduce(
self.func, dim=[dim.value for dim in self.dims], **self.kwargs)
self.func.resolve(), dim=[dim.value for dim in self.dims], **self.kwargs)

# Add the reduced dims back and align with the original stack
reduced = reduced.expand_dims(tuple(dim.value for dim in self.dims))
Expand Down
5 changes: 5 additions & 0 deletions starfish/core/image/Filter/test/test_map.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from starfish.core.imagestack.test.factories import synthetic_stack
from starfish.core.types import FunctionSource
from .. import Map


Expand All @@ -9,3 +10,7 @@ def test_map():
mapper = Map("divide", 2)
output = mapper.run(stack)
assert (output.xarray == 0.5).all()

mapper = Map(FunctionSource.np("divide"), 2)
output = mapper.run(stack)
assert (output.xarray == 0.5).all()
22 changes: 22 additions & 0 deletions starfish/core/image/Filter/test/test_reduce.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections import OrderedDict

import numpy as np
Expand Down Expand Up @@ -103,6 +104,12 @@ def make_expected_image_stack(func):
FunctionSource.scipy,
{'ord': 2},
),
(
make_expected_image_stack('norm'),
FunctionSource.scipy('linalg.norm'),
None,
{'ord': 2},
),
]
)
def test_image_stack_reduce(expected_result, func, module, kwargs):
Expand All @@ -123,6 +130,21 @@ def test_image_stack_reduce(expected_result, func, module, kwargs):
assert np.allclose(reduced.xarray, expected_result.xarray)


def test_image_stack_module_deprecated():
"""Specifying the function as a string and passing in a module should generate a warning."""
with warnings.catch_warnings(record=True) as all_warnings:
Reduce(dims=[Axes.ROUND], func="max", module=FunctionSource.np)

assert DeprecationWarning in (warning.category for warning in all_warnings)


def test_image_stack_module_with_functionsourcebundle():
"""Specifying the function as a FunctionSourceBundle and passing in a module should raise an
Exception."""
with pytest.raises(ValueError):
Reduce(dims=[Axes.ROUND], func=FunctionSource.np("max"), module=FunctionSource.np)


def test_max_projection_preserves_coordinates():
e = data.ISS(use_test_data=True)
nuclei = e.fov().get_image('nuclei')
Expand Down
9 changes: 5 additions & 4 deletions starfish/core/imagestack/imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Coordinates,
CoordinateValue,
FunctionSource,
FunctionSourceBundle,
Levels,
Number,
STARFISH_EXTRAS_KEY,
Expand Down Expand Up @@ -1180,8 +1181,8 @@ def _squeezed_numpy(self, *dims: Axes):
def reduce(
self,
dims: Iterable[Union[Axes, str]],
func: str,
module: FunctionSource = FunctionSource.np,
func: Union[str, FunctionSourceBundle],
module: Optional[FunctionSource] = None,
clip_method: Optional[Clip] = None,
level_method: Optional[Levels] = None,
*args,
Expand All @@ -1203,8 +1204,8 @@ def reduce(

def map(
self,
func: str,
module: FunctionSource = FunctionSource.np,
func: Union[str, FunctionSourceBundle],
module: Optional[FunctionSource] = None,
in_place: bool = False,
group_by: Optional[Set[Union[Axes, str]]] = None,
clip_method: Optional[Clip] = None,
Expand Down
Loading

0 comments on commit a15237c

Please sign in to comment.