-
Notifications
You must be signed in to change notification settings - Fork 6
/
_viewer.py
532 lines (454 loc) · 21.2 KB
/
_viewer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from itertools import cycle
from typing import TYPE_CHECKING, Literal, cast
import cmap
import numpy as np
from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget
from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread
from superqt.utils import qthrottled, signals_blocked
from ndv.viewer._components import (
ChannelMode,
ChannelModeButton,
DimToggleButton,
QSpinner,
)
from ._backends import get_canvas
from ._data_wrapper import DataWrapper
from ._dims_slider import DimsSliders
from ._lut_control import LutControl
if TYPE_CHECKING:
    # Imports and aliases used only by static type checkers. This branch never runs
    # at runtime, so these names may only appear inside annotations (which
    # `from __future__ import annotations` keeps unevaluated).
    from collections.abc import Hashable
    from concurrent.futures import Future
    from typing import Any, Callable, TypeAlias

    from qtpy.QtGui import QCloseEvent

    from ._dims_slider import DimKey, Indices, Sizes
    from ._protocols import PCanvas, PImageHandle

    # key identifying the image handle(s) / LUT control for one displayed image
    ImgKey: TypeAlias = Hashable

    # any mapping of dimensions to sizes
    SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence]
# neutral gray hex color (used e.g. for the LUT collapsible's chevron icons)
MID_GRAY = "#888888"

# colormap applied to every image when not in composite channel mode
GRAYS = cmap.Colormap("gray")

# colormaps cycled through, in this order, when channels are shown as a composite
DEFAULT_COLORMAPS = [
    cmap.Colormap(_cmap_name)
    for _cmap_name in (
        "green",
        "magenta",
        "cyan",
        "yellow",
        "red",
        "blue",
        "cubehelix",
        "gray",
    )
]

# `slice(None)` selects the full extent of an axis (presumably: all channels)
ALL_CHANNELS = slice(None)
class NDViewer(QWidget):
    """A viewer for ND arrays.

    This widget displays a single slice from an ND array (or a composite of slices in
    different colormaps). The widget provides sliders to select the slice to display,
    and buttons to control the display mode of the channels.

    An important concept in this widget is the "index". The index is a mapping of
    dimensions to integers or slices that define the slice of the data to display. For
    example, a numpy slice of `[0, 1, 5:10]` would be represented as
    `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be named, e.g.
    `{'t': 0, 'c': 1, 'z': slice(5, 10)}`. The index is used to select the data from
    the datastore, and to determine the position of the sliders.

    The flow of data is as follows:

    - The user sets the data using the `set_data` method. This will set the number
      and range of the sliders to the shape of the data, and display the first slice.
    - The user can then use the sliders to select the slice to display. The current
      slice is defined as a `Mapping` of `{dim -> int|slice}` and can be retrieved
      with the `_dims_sliders.value()` method. To programmatically set the current
      position, use the `setIndex` method. This will set the values of the sliders,
      which in turn will trigger the display of the new slice via the
      `_update_data_for_index` method.
    - `_update_data_for_index` is an asynchronous method that retrieves the data for
      the given index from the datastore (using `DataWrapper.isel_async`) and queues
      the `_on_data_slice_ready` method to be called when the data is ready. The logic
      for extracting data from the datastore is defined in `_data_wrapper.py`, which
      handles idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc).
    - `_on_data_slice_ready` is called when the data is ready, and updates the image.
      Note that if the slice has more dimensions than are displayed (2 or 3), the
      extra dimensions are reduced using max intensity projection (and
      double-clicking on any given dimension slider will turn it into a range slider
      allowing a projection to be made over that dimension).
    - The image is displayed on the canvas, which is an object that implements the
      `PCanvas` protocol (mostly, it has an `add_image` method that returns a handle
      to the added image that can be used to update the data and display). This
      small abstraction allows for various backends to be used (e.g. vispy, pygfx,
      etc).

    Parameters
    ----------
    data : Any
        The data to display. This can be any duck-like ND array, including numpy, dask,
        xarray, jax, tensorstore, zarr, etc. You can add support for new datastores by
        subclassing `DataWrapper` and implementing the required methods. See
        `DataWrapper` for more information. If None, nothing is displayed until
        `set_data` is called.
    colormaps : Iterable[ColorStopsLike], optional
        Colormaps cycled through when displaying channels in composite mode. Each item
        is converted with `cmap.Colormap`. If not provided, `DEFAULT_COLORMAPS` is used.
    parent : QWidget, optional
        The parent widget of this widget.
    channel_axis : Hashable, optional
        The axis that represents the channels in the data. If not provided, this will
        be guessed from the data.
    channel_mode : ChannelMode, optional
        The initial mode for displaying the channels. If not provided, this will be
        set to ChannelMode.MONO.
    """

    def __init__(
        self,
        data: DataWrapper | Any,
        *,
        colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None,
        parent: QWidget | None = None,
        channel_axis: DimKey | None = None,
        channel_mode: ChannelMode | str = ChannelMode.MONO,
    ):
        super().__init__(parent=parent)

        # ATTRIBUTES ----------------------------------------------------

        # mapping of key to a list of objects that control image nodes in the canvas
        self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list)
        # mapping of same keys to the LutControl objects control image display props
        self._lut_ctrls: dict[ImgKey, LutControl] = {}
        # the set of dimensions we are currently visualizing (e.g. XY)
        # this is used to control which dimensions have sliders and the behavior
        # of isel when selecting data from the datastore
        self._visualized_dims: set[DimKey] = set()
        # the axis that represents the channels in the data
        self._channel_axis = channel_axis
        self._channel_mode: ChannelMode = None  # type: ignore # set in set_channel_mode
        # colormaps that will be cycled through when displaying composite images
        if colormaps is not None:
            self._cmaps = [cmap.Colormap(c) for c in colormaps]
        else:
            # copy, so that mutating this viewer's list never alters the module default
            self._cmaps = list(DEFAULT_COLORMAPS)
        self._cmap_cycle = cycle(self._cmaps)
        # the last future that was created by _update_data_for_index
        self._last_future: Future | None = None

        # number of dimensions to display
        self._ndims: Literal[2, 3] = 2

        # WIDGETS ----------------------------------------------------

        # the button that controls the display mode of the channels
        self._channel_mode_btn = ChannelModeButton(self)
        self._channel_mode_btn.clicked.connect(self.set_channel_mode)
        # button to reset the zoom of the canvas
        self._set_range_btn = QPushButton(
            QIconifyIcon("fluent:full-screen-maximize-24-filled"), "", self
        )
        self._set_range_btn.clicked.connect(self._on_set_range_clicked)

        # button to change number of displayed dimensions
        self._ndims_btn = DimToggleButton(self)
        self._ndims_btn.clicked.connect(self.toggle_3d)

        # place to display dataset summary
        self._data_info_label = QElidingLabel("", parent=self)
        self._progress_spinner = QSpinner(self)

        # place to display arbitrary text
        self._hover_info_label = QLabel("", self)
        # the canvas that displays the images
        self._canvas: PCanvas = get_canvas()(self._hover_info_label.setText)
        self._canvas.set_ndim(self._ndims)

        # the sliders that control the index of the displayed image
        self._dims_sliders = DimsSliders(self)
        # throttle slider changes so data requests are issued at most every 20 ms
        self._dims_sliders.valueChanged.connect(
            qthrottled(self._update_data_for_index, 20, leading=True)
        )

        self._lut_drop = QCollapsible("LUTs", self)
        self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY))
        self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up", color=MID_GRAY))
        lut_layout = cast("QVBoxLayout", self._lut_drop.layout())
        lut_layout.setContentsMargins(0, 1, 0, 1)
        lut_layout.setSpacing(0)
        # `_content` is private to QCollapsible, so only touch it if it exists
        if (
            hasattr(self._lut_drop, "_content")
            and (layout := self._lut_drop._content.layout()) is not None
        ):
            layout.setContentsMargins(0, 0, 0, 0)
            layout.setSpacing(0)

        # LAYOUT -----------------------------------------------------

        self._btns = btns = QHBoxLayout()
        btns.setContentsMargins(0, 0, 0, 0)
        btns.setSpacing(0)
        btns.addStretch()
        btns.addWidget(self._channel_mode_btn)
        btns.addWidget(self._ndims_btn)
        btns.addWidget(self._set_range_btn)

        info = QHBoxLayout()
        info.setContentsMargins(0, 0, 0, 2)
        info.setSpacing(0)
        info.addWidget(self._data_info_label)
        info.addWidget(self._progress_spinner)

        layout = QVBoxLayout(self)
        layout.setSpacing(2)
        layout.setContentsMargins(6, 6, 6, 6)
        layout.addLayout(info)
        layout.addWidget(self._canvas.qwidget(), 1)
        layout.addWidget(self._hover_info_label)
        layout.addWidget(self._dims_sliders)
        layout.addWidget(self._lut_drop)
        layout.addLayout(btns)

        # SETUP ------------------------------------------------------

        self.set_channel_mode(channel_mode)
        if data is not None:
            self.set_data(data)

    # ------------------- PUBLIC API ----------------------------
    @property
    def dims_sliders(self) -> DimsSliders:
        """Return the DimsSliders widget."""
        return self._dims_sliders

    @property
    def data_wrapper(self) -> DataWrapper:
        """Return the DataWrapper object around the datastore."""
        return self._data_wrapper

    @property
    def data(self) -> Any:
        """Return the data backing the view."""
        return self._data_wrapper.data

    @data.setter
    def data(self, data: Any) -> None:
        """Set the data backing the view."""
        raise AttributeError("Cannot set data directly. Use `set_data` method.")

    def set_data(
        self,
        data: DataWrapper | Any,
        channel_axis: DimKey | None = None,
        visualized_dims: Iterable[DimKey] | None = None,
    ) -> None:
        """Set the datastore to display.

        Parameters
        ----------
        data : DataWrapper | Any
            The new data; it is wrapped with `DataWrapper.create`.
        channel_axis : DimKey, optional
            The channel axis. If omitted, a previously set channel axis is kept, or
            one is guessed from the data if none was set before.
        visualized_dims : Iterable[DimKey], optional
            Dimensions drawn on the canvas (these get no sliders). Defaults to the
            last `ndim` (2 or 3) dimensions of the data.
        """
        # store the data
        self._data_wrapper = DataWrapper.create(data)

        # set channel axis
        if channel_axis is not None:
            self._channel_axis = channel_axis
        elif self._channel_axis is None:
            self._channel_axis = self._data_wrapper.guess_channel_axis()

        # update the dimensions we are visualizing
        if visualized_dims is None:
            sizes = self._data_wrapper.sizes()
            visualized_dims = list(sizes)[-self._ndims :]
        self.set_visualized_dims(visualized_dims)

        # update the range of all the sliders to match the sizes we set above
        with signals_blocked(self._dims_sliders):
            self.update_slider_ranges()

        # redraw
        self.setIndex({})

        # update the data info label
        self._data_info_label.setText(self._data_wrapper.summary_info())

    def set_visualized_dims(self, dims: Iterable[DimKey]) -> None:
        """Set the dimensions that will be visualized.

        These dims will NOT have sliders associated with them; every other existing
        slider is made visible.
        """
        self._visualized_dims = set(dims)
        for d in self._dims_sliders._sliders:
            self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims)
        for d in self._visualized_dims:
            self._dims_sliders.set_dimension_visible(d, False)

    def update_slider_ranges(
        self, mins: SizesLike | None = None, maxes: SizesLike | None = None
    ) -> None:
        """Set the minimum and maximum values of the sliders.

        If `maxes` is not provided, sizes are inferred from the datastore. Each
        slider's maximum is set to ``size - 1``. `mins`, when given, is applied as-is.
        This is mostly here as a public way to reset the slider ranges (e.g. after
        the data has changed).
        """
        if maxes is None:
            maxes = self._data_wrapper.sizes()
        else:
            maxes = _to_sizes(maxes)
        self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()})
        if mins is not None:
            self._dims_sliders.setMinima(_to_sizes(mins))

        # FIXME: this needs to be moved and made user-controlled
        for dim in list(maxes.keys())[-self._ndims :]:
            self._dims_sliders.set_dimension_visible(dim, False)

    def toggle_3d(self) -> None:
        """Switch between 2D and 3D display."""
        self.set_ndim(3 if self._ndims == 2 else 2)

    def set_ndim(self, ndim: Literal[2, 3]) -> None:
        """Set the number of dimensions to display.

        Raises
        ------
        ValueError
            If `ndim` is not 2 or 3.
        """
        if ndim not in (2, 3):
            raise ValueError(f"ndim must be 2 or 3, not {ndim!r}")
        self._ndims = ndim
        self._canvas.set_ndim(ndim)

        # The 2D/3D button is live before any data is set (``data=None`` at
        # construction); there is then no slider to toggle and nothing to redraw.
        if getattr(self, "_data_wrapper", None) is None:
            return

        # show the slider of the third-to-last non-channel dimension in 2D mode and
        # hide it in 3D mode
        sizes = list(self._data_wrapper.sizes())
        if self._channel_axis is not None:
            sizes = [x for x in sizes if x != self._channel_axis]
        if len(sizes) >= 3:
            dim3 = sizes[-3]
            self._dims_sliders.set_dimension_visible(dim3, ndim == 2)

        # clear image handles and redraw
        if self._img_handles:
            self._clear_images()
            self._update_data_for_index(self._dims_sliders.value())

    def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None:
        """Set the mode for displaying the channels.

        In ``ChannelMode.COMPOSITE`` mode, every channel along ``self._channel_axis``
        is requested and drawn as its own image, using colormaps cycled from
        ``self._cmaps``, and the channel slider is hidden. In any other mode a single
        image (gray colormap) is drawn for the channel selected by the slider.
        If `mode` is None (or a bool, as emitted by the button's ``clicked`` signal),
        the current mode of the channel-mode button is used.
        """
        if mode is None or isinstance(mode, bool):
            mode = self._channel_mode_btn.mode()
        else:
            mode = ChannelMode(mode)
            self._channel_mode_btn.setMode(mode)
        if mode == getattr(self, "_channel_mode", None):
            return

        self._channel_mode = mode
        self._cmap_cycle = cycle(self._cmaps)  # reset the colormap cycle
        if self._channel_axis is not None:
            # set the visibility of the channel slider
            self._dims_sliders.set_dimension_visible(
                self._channel_axis, mode != ChannelMode.COMPOSITE
            )

        if self._img_handles:
            self._clear_images()
            self._update_data_for_index(self._dims_sliders.value())

    def setIndex(self, index: Indices) -> None:
        """Set the index of the displayed image."""
        self._dims_sliders.setValue(index)

    # ------------------- PRIVATE METHODS ----------------------------

    def _on_set_range_clicked(self) -> None:
        # using method to swallow the parameter passed by _set_range_btn.clicked
        self._canvas.set_range()

    def _image_key(self, index: Indices) -> ImgKey:
        """Return the key for image handle(s) corresponding to `index`.

        In composite mode each channel gets its own key: the channel index, or a
        ``(start, stop)`` tuple when the channel index is a slice. In every other
        mode all data shares the single key ``0``.
        """
        if self._channel_mode == ChannelMode.COMPOSITE:
            val = index.get(self._channel_axis, 0)
            if isinstance(val, slice):
                return (val.start, val.stop)
            return val
        return 0

    def _update_data_for_index(self, index: Indices) -> None:
        """Retrieve data for `index` from datastore and update canvas image(s).

        This will pull the data from the datastore using the given index, and update
        the image handle(s) with the new data. This method is *asynchronous*. It
        makes a request for the new data slice and queues `_on_data_slice_ready` to be
        called when the data is ready. Any previous pending request is cancelled.
        """
        if (
            self._channel_axis is not None
            and self._channel_mode == ChannelMode.COMPOSITE
            and self._channel_axis in (sizes := self._data_wrapper.sizes())
        ):
            # composite: request one slice per channel
            indices: list[Indices] = [
                {**index, self._channel_axis: i}
                for i in range(sizes[self._channel_axis])
            ]
        else:
            indices = [index]

        if self._last_future is not None:
            self._last_future.cancel()

        # strip the visualized dimensions from each index so that the slider values
        # do not constrain them
        indices = [
            {k: v for k, v in idx.items() if k not in self._visualized_dims}
            for idx in indices
        ]
        try:
            self._last_future = f = self._data_wrapper.isel_async(indices)
        except Exception as e:
            raise type(e)(f"Failed to index data with {index}: {e}") from e

        # Show the spinner *before* registering the callback: if `f` is already done,
        # the callback may run synchronously and must be able to hide it again.
        self._progress_spinner.show()
        f.add_done_callback(self._on_data_slice_ready)

    def closeEvent(self, a0: QCloseEvent | None) -> None:
        """Cancel any pending data request, then close the widget."""
        if self._last_future is not None:
            self._last_future.cancel()
            self._last_future = None
        super().closeEvent(a0)

    @ensure_main_thread  # type: ignore
    def _on_data_slice_ready(
        self, future: Future[Iterable[tuple[Indices, np.ndarray]]]
    ) -> None:
        """Update the displayed image(s) once the data request in `future` is done.

        Registered as a done-callback on the future returned by `isel_async` in
        `_update_data_for_index`.
        """
        # Only the most recent request owns the bookkeeping. An older request that was
        # already running when it was superseded cannot be cancelled and may finish
        # later; it must not drop the reference to (or hide the spinner of) the newer
        # request that is still in flight.
        if future is self._last_future:
            # NOTE: removing the reference to the last future here is important
            # because the future holds a reference to this widget (the bound callback
            # in its _done_callbacks), which would otherwise keep the widget alive for
            # as long as the future is referenced.
            self._last_future = None
            self._progress_spinner.hide()
        if future.cancelled():
            return

        for idx, datum in future.result():
            self._update_canvas_data(datum, idx)
        self._canvas.refresh()

    def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None:
        """Actually update the image handle(s) with the (sliced) data.

        By this point, data should be sliced from the underlying datastore. Any
        dimensions beyond the number of displayed dimensions (2 or 3) are reduced by
        `_reduce_data_for_display` (max intensity projection by default). The first
        time a key is seen, a new image/volume handle and a LutControl are created.
        """
        imkey = self._image_key(index)
        datum = self._reduce_data_for_display(data)
        if handles := self._img_handles[imkey]:
            for handle in handles:
                handle.data = datum
                if ctrl := self._lut_ctrls.get(imkey, None):
                    ctrl.update_autoscale()
        else:
            cm = (
                next(self._cmap_cycle)
                if self._channel_mode == ChannelMode.COMPOSITE
                else GRAYS
            )
            if datum.ndim == 2:
                handles.append(self._canvas.add_image(datum, cmap=cm))
            elif datum.ndim == 3:
                handles.append(self._canvas.add_volume(datum, cmap=cm))
            if imkey not in self._lut_ctrls:
                channel_name = self._get_channel_name(index)
                self._lut_ctrls[imkey] = c = LutControl(
                    channel_name,
                    handles,
                    self,
                    cmaplist=self._cmaps + DEFAULT_COLORMAPS,
                )
                self._lut_drop.addWidget(c)

    def _get_channel_name(self, index: Indices) -> str:
        """Return a display name for the channel selected in `index`."""
        c = index.get(self._channel_axis, 0)
        return f"Ch {c}"  # TODO: get name from user

    def _reduce_data_for_display(
        self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max
    ) -> np.ndarray:
        """Reduce the number of dimensions in the data for display.

        Singleton dimensions are squeezed first. If more dimensions than the number
        displayed (2 or 3) remain, the smallest dimensions are reduced with
        `reductor` (np.max by default). Data with fewer dimensions is left untouched.
        This also coerces 64-bit data to 32-bit data.
        """
        # TODO
        # - allow dimensions to control how they are reduced (as opposed to just max)
        # - for better way to determine which dims need to be reduced (currently just
        #   the smallest dims)
        data = data.squeeze()
        visualized_dims = self._ndims
        # Only reduce when there are MORE dims than can be displayed. A negative
        # difference (e.g. 2D data in 3D mode) is truthy, and `shapes[:-n]` would then
        # wrongly project away real axes.
        if (extra_dims := data.ndim - visualized_dims) > 0:
            shapes = sorted(enumerate(data.shape), key=lambda x: x[1])
            smallest_dims = tuple(i for i, _ in shapes[:extra_dims])
            data = reductor(data, axis=smallest_dims)

        if data.dtype.itemsize > 4:  # More than 32 bits
            if np.issubdtype(data.dtype, np.integer):
                data = data.astype(np.int32)
            else:
                data = data.astype(np.float32)
        return data

    def _clear_images(self) -> None:
        """Remove all images from the canvas, along with their LutControls."""
        for handles in self._img_handles.values():
            for handle in handles:
                handle.remove()
        self._img_handles.clear()

        # clear the current LutControls as well; they were added to the LUT
        # collapsible (not to this widget's own layout), so remove them from there
        for c in self._lut_ctrls.values():
            self._lut_drop.removeWidget(c)
            c.deleteLater()
        self._lut_ctrls.clear()
def _to_sizes(sizes: SizesLike | None) -> Sizes:
"""Coerce `sizes` to a {dimKey -> int} mapping."""
if sizes is None:
return {}
if isinstance(sizes, Mapping):
return {k: int(v) for k, v in sizes.items()}
if not isinstance(sizes, Iterable):
raise TypeError(f"SizeLike must be an iterable or mapping, not: {type(sizes)}")
_sizes: dict[Hashable, int] = {}
for i, val in enumerate(sizes):
if isinstance(val, int):
_sizes[i] = val
elif isinstance(val, Sequence) and len(val) == 2:
_sizes[val[0]] = int(val[1])
else:
raise ValueError(f"Invalid size: {val}. Must be an int or a 2-tuple.")
return _sizes