Skip to content

Commit

Permalink
funcs passed to apply and transform can use positional arguments (#1519)
Browse files Browse the repository at this point in the history
With the existing code, funcs passed to apply and transform can only use keyword arguments. This is unnecessarily limiting, and if we want to support #1446, we need to be able to work with a variety of functions with their own argument calling mechanisms.
  • Loading branch information
Tony Tung authored Sep 16, 2019
1 parent 93560f8 commit 4446835
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
34 changes: 25 additions & 9 deletions starfish/core/imagestack/imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ def _iter_axes(self, axes: Set[Axes]=None) -> Iterator[Mapping[Axes, int]]:
def apply(
self,
func: Callable,
*args,
group_by: Set[Axes]=None,
in_place=False,
verbose: bool=False,
Expand Down Expand Up @@ -772,20 +773,23 @@ def apply(
# create a copy of the ImageStack, call apply on that stack with in_place=True
image_stack = deepcopy(self)
image_stack.apply(
func=func,
func,
*args,
group_by=group_by, in_place=True, verbose=verbose, n_processes=n_processes,
clip_method=clip_method,
**kwargs
)
return image_stack

# wrapper adds a target `data` parameter where the results from func will be stored
# data are clipped or scaled by chunk using preserve_float_range if clip_method != 2
# data are clipped or scaled by chunk using preserve_float_range if
# clip_method != SCALE_BY_IMAGE
bound_func = partial(ImageStack._in_place_apply, func, clip_method=clip_method)

# execute the processing workflow
self.transform(
func=bound_func,
bound_func,
*args,
group_by=group_by,
verbose=verbose,
n_processes=n_processes,
Expand All @@ -799,10 +803,13 @@ def apply(

@staticmethod
def _in_place_apply(
apply_func: Callable[..., Union[xr.DataArray, np.ndarray]], data: np.ndarray,
clip_method: Union[str, Clip], **kwargs
apply_func: Callable[..., Union[xr.DataArray, np.ndarray]],
data: np.ndarray,
*args,
clip_method: Union[str, Clip],
**kwargs
) -> None:
result = apply_func(data, **kwargs)
result = apply_func(data, *args, **kwargs)
if clip_method == Clip.CLIP:
data[:] = preserve_float_range(result, rescale=False)
elif clip_method == Clip.SCALE_BY_CHUNK:
Expand All @@ -813,6 +820,7 @@ def _in_place_apply(
def transform(
self,
func: Callable,
*args,
group_by: Set[Axes]=None,
verbose=False,
n_processes: Optional[int]=None,
Expand Down Expand Up @@ -862,7 +870,13 @@ def transform(
}

mp_applyfunc: Callable = partial(
self._processing_workflow, partial(func, **kwargs), self.xarray.dims, coordinates)
self._processing_workflow,
func,
self.xarray.dims,
coordinates,
args,
kwargs,
)

with Pool(
processes=n_processes,
Expand All @@ -878,9 +892,11 @@ def transform(

@staticmethod
def _processing_workflow(
worker_callable: Callable[[np.ndarray], Any],
worker_callable: Callable,
xarray_dims: Sequence[str],
xarray_coordinates: Mapping[str, np.ndarray],
args: Sequence,
kwargs: Mapping,
selector_and_slice_list: Tuple[Mapping[Axes, int],
Tuple[Union[int, slice], ...]],
):
Expand All @@ -897,7 +913,7 @@ def _processing_workflow(
sliced = data_array.sel(selector_and_slice_list[0])

# pass worker_callable a view into the backing array, which will be overwritten
return worker_callable(sliced) # type: ignore
return worker_callable(sliced, *args, **kwargs) # type: ignore

@property
def tile_metadata(self) -> pd.DataFrame:
Expand Down
9 changes: 9 additions & 0 deletions starfish/core/imagestack/test/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ def test_apply():
assert (output.xarray == 0.5).all()


def test_apply_positional():
"""test that apply correctly applies a simple function across 2d tiles of a Stack. Unlike
test_apply, the parameter is passed in as a positional parameter."""
stack = synthetic_stack()
assert (stack.xarray == 1).all()
output = stack.apply(divide, 2, n_processes=1)
assert (output.xarray == 0.5).all()


def test_apply_3d():
"""test that apply correctly applies a simple function across 3d volumes of a Stack"""
stack = synthetic_stack()
Expand Down

0 comments on commit 4446835

Please sign in to comment.