-
Notifications
You must be signed in to change notification settings - Fork 59
/
cell_magic_wand.py
225 lines (186 loc) · 9.14 KB
/
cell_magic_wand.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
###########################################################################
#
# Python implementation of the ImageJ Cell Magic Wand plugin
# (http://www.maxplanckflorida.org/fitzpatricklab/software/cellMagicWand/)
# with modifications to reduce variability due to seed point selection
# and to support edge detection using all z slices of a 3D image
#
# Author: Noah Apthorpe (apthorpe@cs.princeton.edu)
#
# Description: Draws a border within a specified radius
# around a specified center "seed" point
# using a polar transform and a dynamic
# programming edge-following algorithm
#
# Usage: Import and call the cell_magic_wand() function
# or the cell_magic_wand_3d() function with
# a source image, radius window, and location of the center
# point. Other parameters are set as optional arguments.
# Both return a binary mask with 1s inside the detected edge;
# cell_magic_wand_single_point() additionally returns a list
# of points along the detected edge.
#
###########################################################################
import numpy as np
from scipy.ndimage.interpolation import zoom
from scipy.ndimage.morphology import binary_fill_holes
def coord_polar_to_cart(r, theta, center):
    '''Map polar coordinates (r, theta) about ``center`` to Cartesian (x, y).

    Works element-wise on scalars or NumPy arrays. ``center`` is an (x, y) pair
    and theta is in radians.
    '''
    cx, cy = center[0], center[1]
    return r * np.cos(theta) + cx, r * np.sin(theta) + cy
def coord_cart_to_polar(x, y, center):
    '''Map Cartesian (x, y) to polar (r, theta) about ``center``.

    Works element-wise on scalars or NumPy arrays; theta comes from
    ``np.arctan2`` and is therefore in radians.
    '''
    dx = x - center[0]
    dy = y - center[1]
    radius = np.sqrt(dx ** 2 + dy ** 2)
    angle = np.arctan2(dy, dx)
    return radius, angle
def image_cart_to_polar(image, center, min_radius, max_radius, phase_width, zoom_factor=1):
    '''Resample a 2D image onto a polar grid around ``center``.

    Args:
        image: 2D array indexed as image[x, y].
        center: (x, y) seed point in image coordinates.
        min_radius, max_radius: radial window; rows are sampled at
            np.arange(min_radius, max_radius) (after scaling by zoom_factor).
        phase_width: number of angular samples, np.linspace(0, 2*pi,
            phase_width) (both endpoints included).
        zoom_factor: if != 1 the image is first upsampled with scipy's zoom
            (spline order 4) and center and radii are scaled to match.

    Returns:
        Array of shape (len(np.arange(min_radius, max_radius)), phase_width),
        in zoomed units, holding nearest-pixel samples. Samples that fall
        outside the image read as 0.
    '''
    # Upsample image
    if zoom_factor != 1:
        image = zoom(image, (zoom_factor, zoom_factor), order=4)
        center = (center[0] * zoom_factor + zoom_factor / 2,
                  center[1] * zoom_factor + zoom_factor / 2)
        min_radius = min_radius * zoom_factor
        max_radius = max_radius * zoom_factor
    # Zero-pad so every sample within max_radius of center is in bounds.
    # ceil (not int truncation) so a fractional overshoot still gets padding.
    size_x, size_y = image.shape[0], image.shape[1]
    pad_dist = int(np.ceil(max(0,
                               center[0] + max_radius - size_x, max_radius - center[0],
                               center[1] + max_radius - size_y, max_radius - center[1])))
    if pad_dist > 0:
        image = np.pad(image, pad_dist, 'constant')
        # Padding shifts the image content by pad_dist along both axes; the
        # center must move with it or the samples come from the wrong place.
        center = (center[0] + pad_dist, center[1] + pad_dist)
    # coordinate conversion: rows = radius, columns = angle
    theta, r = np.meshgrid(np.linspace(0, 2 * np.pi, phase_width),
                           np.arange(min_radius, max_radius))
    x, y = coord_polar_to_cart(r, theta, center)
    x = np.round(x).astype(int)
    y = np.round(y).astype(int)
    return image[x, y]
def mask_polar_to_cart(mask, center, min_radius, max_radius, output_shape, zoom_factor=1):
    '''Rasterise a polar binary mask into a Cartesian image of zeros.

    Row i of ``mask`` is taken as radius i (in zoomed pixels), and column j as
    the j-th angle of np.linspace(0, 2*pi, mask.shape[1]). Each polar sample is
    written to its nearest pixel, clipped to the image border. If
    zoom_factor != 1, the result is downsampled back with scipy's zoom.

    Args:
        mask: 2D polar mask (radius rows x angle columns).
        center: (x, y) seed point in original image coordinates.
        min_radius, max_radius: accepted for interface compatibility; the
            radial extent is taken from mask.shape[0].
        output_shape: shape of the original (un-zoomed) image.
        zoom_factor: upsampling factor that was used to build the polar image.

    Returns:
        Boolean array, thresholded at 0.5 and hole-filled with
        binary_fill_holes.
    '''
    # Account for upsampling
    if zoom_factor != 1:
        center = (center[0] * zoom_factor + zoom_factor / 2,
                  center[1] * zoom_factor + zoom_factor / 2)
        # Needs a concrete integer tuple: a lazy ``map`` object is not a valid
        # shape for np.zeros on Python 3. Rounded like scipy's zoom output.
        output_shape = tuple(int(round(a * zoom_factor)) for a in output_shape)
    # new image
    image = np.zeros(output_shape)
    # coordinate conversion; one radius per mask row so shapes always match
    theta, r = np.meshgrid(np.linspace(0, 2 * np.pi, mask.shape[1]),
                           np.arange(0, mask.shape[0]))
    x, y = coord_polar_to_cart(r, theta, center)
    x = np.clip(np.round(x).astype(int), 0, image.shape[0] - 1)
    y = np.clip(np.round(y).astype(int), 0, image.shape[1] - 1)
    image[x, y] = mask
    # downsample image
    if zoom_factor != 1:
        zf = 1 / float(zoom_factor)
        image = zoom(image, (zf, zf), order=4)
    # ensure image remains a filled binary mask
    image = (image > 0.5).astype(int)
    image = binary_fill_holes(image)
    return image
def find_edge_2d(polar, min_radius):
    '''Trace an edge across a polar image, one angular column at a time.

    For every cell (r, t), the score of stepping to row r+1, r or r-1 in
    column t+1 is polar[r, t] + polar[r', t+1]. Neighbours outside the image
    count as 0. The best step is recorded for each cell. Tracing starts at
    the row whose best score in column 0 is largest (the last such row on
    ties). It then follows the recorded steps, clamped to the valid row
    range. This is a greedy one-step lookahead, not a global optimum.

    Args:
        polar: 2D array, radius rows x angle columns.
        min_radius: radius of polar row 0. It is added to the reported radii
            and sets how many all-ones rows are prepended to the mask. It is
            rounded to an int.

    Returns:
        (edge, mask):
        - edge is an int array of (radius, column index) pairs, one per column.
        - mask has shape (min_radius + rows, columns), with 1s at and inside
          the traced edge.

    Raises:
        ValueError: if polar is not 2D.
    '''
    if len(polar.shape) != 2:
        raise ValueError("argument to find_edge_2d must be 2D")
    n_rows, n_cols = polar.shape
    offset = int(round(min_radius))
    # Scores of the three possible moves into column t+1 (zero padded).
    values_right_shift = np.pad(polar, ((0, 0), (0, 1)), 'constant', constant_values=0)[:, 1:]
    values_closeright_shift = np.pad(polar, ((1, 0), (0, 1)), 'constant', constant_values=0)[:-1, 1:]
    values_awayright_shift = np.pad(polar, ((0, 1), (0, 1)), 'constant', constant_values=0)[1:, 1:]
    values_move = np.zeros((n_rows, n_cols, 3))
    values_move[:, :, 2] = polar + values_closeright_shift  # to row r-1
    values_move[:, :, 1] = polar + values_right_shift       # stay on row r
    values_move[:, :, 0] = polar + values_awayright_shift   # to row r+1
    values = np.amax(values_move, axis=2)
    # argmax index 0 / 1 / 2 -> row step +1 / 0 / -1
    directions = 1 - np.argmax(values_move, axis=2)
    # Start at the best-scoring row of column 0 (last one on ties). Using the
    # true maximum, not a max floored at 0, so negative inputs work too.
    start_scores = values[:, 0]
    r = n_rows - 1 - int(np.argmax(start_scores[::-1]))
    rows = np.empty(n_cols, dtype=int)
    rows[0] = r
    for t in range(1, n_cols):
        r = int(min(max(r + directions[r, t - 1], 0), n_rows - 1))
        rows[t] = r
    edge = np.column_stack((rows + offset, np.arange(n_cols)))
    # Inside of the edge: every row at or below the traced row in each column.
    mask = (np.arange(n_rows)[:, None] <= rows[None, :]).astype(float)
    # add to inside of mask accounting for min_radius
    new_mask = np.ones((offset + n_rows, n_cols))
    new_mask[offset:, :] = mask
    return edge, new_mask
def edge_polar_to_cart(edge, center):
    '''Convert an iterable of polar (radius, theta) edge points to Cartesian.

    Each point is converted with coord_polar_to_cart around ``center``. Both
    coordinates are rounded with the built-in ``round``. Returns a list of
    (x, y) tuples in the same order as ``edge``.
    '''
    converted = (coord_polar_to_cart(radius, angle, center) for radius, angle in edge)
    return [(round(px), round(py)) for px, py in converted]
def cell_magic_wand_single_point(image, center, min_radius, max_radius,
                                 roughness=2, zoom_factor=1):
    '''Draws a border within a specified radius around a specified center "seed" point
    using a polar transform and an edge-following algorithm.

    The image is resampled onto a polar grid around ``center``
    (image_cart_to_polar). An edge is traced across the angular columns
    (find_edge_2d), and both the edge and the enclosed mask are mapped back
    to Cartesian coordinates.

    Args:
        image: 2D array indexed as image[x, y].
        center: (x, y) seed point.
        min_radius, max_radius: radial search window in pixels. Invalid values
            are corrected, with a printed message.
        roughness: sets the angular resolution,
            phase_width = int(2*pi*max_radius*roughness). Clamped to >= 1.
        zoom_factor: upsampling factor for the polar resampling. Values <= 0
            are replaced by 1.

    Returns:
        (cart_mask, cart_edge):
        - cart_mask is a binary mask with 1s inside the detected edge.
        - cart_edge is a list of rounded (x, y) points along the edge.
    '''
    if roughness < 1:
        roughness = 1
        print("roughness must be >= 1, setting roughness to 1")
    if min_radius < 0:
        min_radius = 0
        print("min_radius must be >=0, setting min_radius to 0")
    if max_radius <= min_radius:
        max_radius = min_radius + 1
        print("max_radius must be larger than min_radius, setting max_radius to " + str(max_radius))
    if zoom_factor <= 0:
        zoom_factor = 1
        print("negative zoom_factor not allowed, setting zoom_factor to 1")
    phase_width = int(2 * np.pi * max_radius * roughness)
    polar_image = image_cart_to_polar(image, center, min_radius, max_radius,
                                      phase_width=phase_width, zoom_factor=zoom_factor)
    # The polar image is sampled in zoomed pixels starting at
    # min_radius * zoom_factor, so the edge offset and the mask prefix must
    # use that zoomed radius.
    zoomed_min_radius = int(round(min_radius * zoom_factor))
    polar_edge, polar_mask = find_edge_2d(polar_image, zoomed_min_radius)
    # find_edge_2d reports (zoomed radius, column index). Convert to the
    # original pixel scale and to the radian angle of that column, which
    # follows the np.linspace(0, 2*pi, phase_width) grid.
    radii = polar_edge[:, 0] / float(zoom_factor)
    thetas = polar_edge[:, 1] * (2 * np.pi / max(phase_width - 1, 1))
    cart_edge = edge_polar_to_cart(zip(radii, thetas), center)
    cart_mask = mask_polar_to_cart(polar_mask, center, min_radius, max_radius,
                                   image.shape, zoom_factor=zoom_factor)
    return cart_mask, cart_edge
def cell_magic_wand(image, center, min_radius, max_radius,
                    roughness=2, zoom_factor=1, center_range=2):
    '''Run the single-seed wand on a 3x3 grid of seeds around ``center`` and vote.

    Seeds are offset by -center_range, 0 and +center_range along each axis.
    The nine resulting masks are averaged, and pixels whose average exceeds
    0.5 are kept. This gives a more robust edge detection than the vanilla
    single-seed wand tool.

    Returns:
        Integer (0/1) mask with the first two dimensions of ``image``.
    '''
    offsets = (-center_range, 0, center_range)
    seeds = [(center[0] + dx, center[1] + dy) for dx in offsets for dy in offsets]
    votes = np.zeros((image.shape[0], image.shape[1], len(seeds)))
    for k, seed in enumerate(seeds):
        seed_mask, _ = cell_magic_wand_single_point(image, seed, min_radius, max_radius,
                                                    roughness=roughness,
                                                    zoom_factor=zoom_factor)
        votes[:, :, k] = seed_mask
    return (np.mean(votes, axis=2) > 0.5).astype(int)
def cell_magic_wand_3d(image_3d, center, min_radius, max_radius,
                       roughness=2, zoom_factor=1, center_range=2, z_step=1):
    '''Robust cell magic wand tool for 3D images indexed as image_3d[z, x, y].

    Runs cell_magic_wand on slices 0, z_step, 2*z_step, ... up to and
    including the last slice that exists. The per-slice masks are averaged
    and thresholded at 0.5.

    Args:
        image_3d: 3D array indexed as image_3d[z, x, y].
        center, min_radius, max_radius, roughness, zoom_factor, center_range:
            forwarded to cell_magic_wand for every processed slice.
        z_step: stride between processed slices. It is truncated to an int
            and must be >= 1.

    Returns:
        Integer (0/1) mask of shape image_3d.shape[1:].

    Raises:
        ValueError: if z_step < 1.
    '''
    if z_step < 1:
        raise ValueError("z_step must be >= 1")
    # range(0, n, step) keeps the trailing slice that int(n / step) dropped
    # whenever n is not a multiple of z_step.
    slice_indices = range(0, image_3d.shape[0], int(z_step))
    masks = np.zeros((len(slice_indices), image_3d.shape[1], image_3d.shape[2]))
    for k, z in enumerate(slice_indices):
        masks[k, :, :] = cell_magic_wand(image_3d[z, :, :], center, min_radius, max_radius,
                                         roughness=roughness, zoom_factor=zoom_factor,
                                         center_range=center_range)
    mean_mask = np.mean(masks, axis=0)
    final_mask = (mean_mask > 0.5).astype(int)
    return final_mask