diff --git a/doc/python/imshow.md b/doc/python/imshow.md
index 91e8976363..1b92bfd2f6 100644
--- a/doc/python/imshow.md
+++ b/doc/python/imshow.md
@@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.2'
- jupytext_version: 1.4.2
+ jupytext_version: 1.3.0
kernelspec:
display_name: Python 3
language: python
@@ -20,7 +20,7 @@ jupyter:
name: python
nbconvert_exporter: python
pygments_lexer: ipython3
- version: 3.7.7
+ version: 3.7.3
plotly:
description: How to display image data in Python with Plotly.
display_as: scientific
@@ -399,6 +399,95 @@ for compression_level in range(0, 9):
fig.show()
```
+### Exploring 3-D images, timeseries and sequences of images with `facet_col`
+
+*Introduced in plotly 4.14*
+
+For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axis the image is sliced through to make the facets. With `facet_col_wrap`, one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to specify the axis by its name (label), thus passing a string to `facet_col`.
+
+It is recommended to use `binary_string=True` for facetted plots of images in order to keep a small figure size and a short rendering time.
+
+See the [tutorial on facet plots](/python/facet-plots/) for more information on creating and styling facet plots.
+
+```python
+import plotly.express as px
+from skimage import io
+from skimage.data import image_fetcher
+path = image_fetcher.fetch('data/cells.tif')
+data = io.imread(path)
+img = data[20:45:2]
+fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5)
+fig.show()
+```
+
+Facets can also be used to represent several images of equal shape, like in the example below where different values of the blurring parameter of a Gaussian filter are compared.
+
+```python
+import plotly.express as px
+import numpy as np
+from skimage import data, filters, img_as_float
+img = data.camera()
+sigmas = [1, 2, 4]
+img_sequence = [filters.gaussian(img, sigma=sigma) for sigma in sigmas]
+fig = px.imshow(np.array(img_sequence), facet_col=0, binary_string=True,
+ labels={'facet_col':'sigma'})
+# Set facet titles
+for i, sigma in enumerate(sigmas):
+ fig.layout.annotations[i]['text'] = 'sigma = %d' %sigma
+fig.show()
+```
+
+```python
+print(fig)
+```
+
+### Exploring 3-D images and timeseries with `animation_frame`
+
+*Introduced in plotly 4.14*
+
+For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by sliding through its different planes in an animation. The `animation_frame` argument of `px.imshow` sets the axis along which the 3-D image is sliced in the animation.
+
+```python
+import plotly.express as px
+from skimage import io
+from skimage.data import image_fetcher
+path = image_fetcher.fetch('data/cells.tif')
+data = io.imread(path)
+img = data[25:40]
+fig = px.imshow(img, animation_frame=0, binary_string=True)
+fig.show()
+```
+
+### Animations of xarray datasets
+
+*Introduced in plotly 4.14*
+
+For xarray datasets, one can pass either an axis number or an axis name to `animation_frame`. Axis names and coordinates are automatically used for the labels, ticks and animation controls of the figure.
+
+```python
+import plotly.express as px
+import xarray as xr
+# Load xarray from dataset included in the xarray tutorial
+ds = xr.tutorial.open_dataset('air_temperature').air[:20]
+fig = px.imshow(ds, animation_frame='time', zmin=220, zmax=300, color_continuous_scale='RdBu_r')
+fig.show()
+```
+
+### Combining animations and facets
+
+It is possible to view 4-dimensional datasets (for example, 3-D images evolving with time) using a combination of `animation_frame` and `facet_col`.
+
+```python
+import plotly.express as px
+from skimage import io
+from skimage.data import image_fetcher
+path = image_fetcher.fetch('data/cells.tif')
+data = io.imread(path)
+data = data.reshape((15, 4, 256, 256))[5:]
+fig = px.imshow(data, animation_frame=0, facet_col=1, binary_string=True)
+fig.show()
+```
+
#### Reference
See https://plotly.com/python/reference/image/ for more information and chart attribute options!
diff --git a/doc/requirements.txt b/doc/requirements.txt
index 71414c34ee..bf6e717f03 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -28,4 +28,5 @@ pyarrow
cufflinks==0.17.3
kaleido
umap-learn
+pooch
wget
diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py
index 88713e5436..27d1bc7349 100644
--- a/packages/python/plotly/plotly/express/_imshow.py
+++ b/packages/python/plotly/plotly/express/_imshow.py
@@ -1,9 +1,10 @@
import plotly.graph_objs as go
from _plotly_utils.basevalidators import ColorscaleValidator
-from ._core import apply_default_cascade
+from ._core import apply_default_cascade, init_figure, configure_animation_controls
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
import pandas as pd
import numpy as np
+import itertools
from plotly.utils import image_array_to_data_uri
try:
@@ -60,6 +61,11 @@ def imshow(
labels={},
x=None,
y=None,
+ animation_frame=None,
+ facet_col=None,
+ facet_col_wrap=None,
+ facet_col_spacing=None,
+ facet_row_spacing=None,
color_continuous_scale=None,
color_continuous_midpoint=None,
range_color=None,
@@ -113,6 +119,26 @@ def imshow(
their lengths must match the lengths of the second and first dimensions of the
img argument. They are auto-populated if the input is an xarray.
+ animation_frame: int or str, optional (default None)
+ axis number along which the image array is sliced to create an animation plot.
+ If `img` is an xarray, `animation_frame` can be the name of one the dimensions.
+
+ facet_col: int or str, optional (default None)
+ axis number along which the image array is sliced to create a facetted plot.
+ If `img` is an xarray, `facet_col` can be the name of one the dimensions.
+
+ facet_col_wrap: int
+ Maximum number of facet columns. Wraps the column variable at this width,
+ so that the column facets span multiple rows.
+ Ignored if `facet_col` is None.
+
+ facet_col_spacing: float between 0 and 1
+ Spacing between facet columns, in paper units. Default is 0.02.
+
+ facet_row_spacing: float between 0 and 1
+ Spacing between facet rows created when ``facet_col_wrap`` is used, in
+ paper units. Default is 0.0.7.
+
color_continuous_scale : str or list of str
colormap used to map scalar data to colors (for a 2D image). This parameter is
not used for RGB or RGBA images. If a string is provided, it should be the name
@@ -204,11 +230,45 @@ def imshow(
args = locals()
apply_default_cascade(args)
labels = labels.copy()
+ nslices_facet = 1
+ if facet_col is not None:
+ if isinstance(facet_col, str):
+ facet_col = img.dims.index(facet_col)
+ nslices_facet = img.shape[facet_col]
+ facet_slices = range(nslices_facet)
+ ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
+ nrows = (
+ nslices_facet // ncols + 1
+ if nslices_facet % ncols
+ else nslices_facet // ncols
+ )
+ else:
+ nrows = 1
+ ncols = 1
+ if animation_frame is not None:
+ if isinstance(animation_frame, str):
+ animation_frame = img.dims.index(animation_frame)
+ nslices_animation = img.shape[animation_frame]
+ animation_slices = range(nslices_animation)
+ slice_dimensions = (facet_col is not None) + (
+ animation_frame is not None
+ ) # 0, 1, or 2
+ facet_label = None
+ animation_label = None
img_is_xarray = False
# ----- Define x and y, set labels if img is an xarray -------------------
if xarray_imported and isinstance(img, xarray.DataArray):
+ dims = list(img.dims)
img_is_xarray = True
- y_label, x_label = img.dims[0], img.dims[1]
+ if facet_col is not None:
+ facet_slices = img.coords[img.dims[facet_col]].values
+ _ = dims.pop(facet_col)
+ facet_label = img.dims[facet_col]
+ if animation_frame is not None:
+ animation_slices = img.coords[img.dims[animation_frame]].values
+ _ = dims.pop(animation_frame)
+ animation_label = img.dims[animation_frame]
+ y_label, x_label = dims[0], dims[1]
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
@@ -223,6 +283,10 @@ def imshow(
labels["x"] = x_label
if labels.get("y", None) is None:
labels["y"] = y_label
+ if labels.get("animation_frame", None) is None:
+ labels["animation_frame"] = animation_label
+ if labels.get("facet_col", None) is None:
+ labels["facet_col"] = facet_label
if labels.get("color", None) is None:
labels["color"] = xarray.plot.utils.label_from_attrs(img)
labels["color"] = labels["color"].replace("\n", "
")
@@ -257,10 +321,29 @@ def imshow(
# --------------- Starting from here img is always a numpy array --------
img = np.asanyarray(img)
+ # Reshape array so that animation dimension comes first, then facets, then images
+ if facet_col is not None:
+ img = np.moveaxis(img, facet_col, 0)
+ if animation_frame is not None and animation_frame < facet_col:
+ animation_frame += 1
+ facet_col = True
+ if animation_frame is not None:
+ img = np.moveaxis(img, animation_frame, 0)
+ animation_frame = True
+ args["animation_frame"] = (
+ "animation_frame"
+ if labels.get("animation_frame") is None
+ else labels["animation_frame"]
+ )
+ iterables = ()
+ if animation_frame is not None:
+ iterables += (range(nslices_animation),)
+ if facet_col is not None:
+ iterables += (range(nslices_facet),)
# Default behaviour of binary_string: True for RGB images, False for 2D
if binary_string is None:
- binary_string = img.ndim >= 3 and not is_dataframe
+ binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe
# Cast bools to uint8 (also one byte)
if img.dtype == np.bool:
@@ -272,7 +355,7 @@ def imshow(
# -------- Contrast rescaling: either minmax or infer ------------------
if contrast_rescaling is None:
- contrast_rescaling = "minmax" if img.ndim == 2 else "infer"
+ contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer"
# We try to set zmin and zmax only if necessary, because traces have good defaults
if contrast_rescaling == "minmax":
@@ -288,19 +371,24 @@ def imshow(
if zmin is None and zmax is not None:
zmin = 0
- # For 2d data, use Heatmap trace, unless binary_string is True
- if img.ndim == 2 and not binary_string:
- if y is not None and img.shape[0] != len(y):
+ # For 2d data, use Heatmap trace, unless binary_string is True
+ if img.ndim == 2 + slice_dimensions and not binary_string:
+ y_index = slice_dimensions
+ if y is not None and img.shape[y_index] != len(y):
raise ValueError(
"The length of the y vector must match the length of the first "
+ "dimension of the img matrix."
)
- if x is not None and img.shape[1] != len(x):
+ x_index = slice_dimensions + 1
+ if x is not None and img.shape[x_index] != len(x):
raise ValueError(
"The length of the x vector must match the length of the second "
+ "dimension of the img matrix."
)
- trace = go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")
+ traces = [
+ go.Heatmap(x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i))
+ for i, index_tup in enumerate(itertools.product(*iterables))
+ ]
autorange = True if origin == "lower" else "reversed"
layout = dict(yaxis=dict(autorange=autorange))
if aspect == "equal":
@@ -319,7 +407,10 @@ def imshow(
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])
# For 2D+RGB data, use Image trace
- elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and binary_string):
+ elif (
+ img.ndim >= 3
+ and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string)
+ ) or (img.ndim == 2 and binary_string):
rescale_image = True # to check whether image has been modified
if zmin is not None and zmax is not None:
zmin, zmax = (
@@ -366,12 +457,12 @@ def imshow(
if zmin is None and zmax is None: # no rescaling, faster
img_rescaled = img
rescale_image = False
- elif img.ndim == 2:
+ elif img.ndim == 2 + slice_dimensions: # single-channel image
img_rescaled = rescale_intensity(
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
)
else:
- img_rescaled = np.dstack(
+ img_rescaled = np.stack(
[
rescale_intensity(
img[..., ch],
@@ -379,27 +470,38 @@ def imshow(
out_range=np.uint8,
)
for ch in range(img.shape[-1])
- ]
+ ],
+ axis=-1,
)
- img_str = image_array_to_data_uri(
- img_rescaled,
- backend=binary_backend,
- compression=binary_compression_level,
- ext=binary_format,
- )
- trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy)
+ img_str = [
+ image_array_to_data_uri(
+ img_rescaled[index_tup],
+ backend=binary_backend,
+ compression=binary_compression_level,
+ ext=binary_format,
+ )
+ for index_tup in itertools.product(*iterables)
+ ]
+
+ traces = [
+ go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy)
+ for i, img_str_slice in enumerate(img_str)
+ ]
else:
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
- trace = go.Image(
- z=img,
- zmin=zmin,
- zmax=zmax,
- colormodel=colormodel,
- x0=x0,
- y0=y0,
- dx=dx,
- dy=dy,
- )
+ traces = [
+ go.Image(
+ z=img[index_tup],
+ zmin=zmin,
+ zmax=zmax,
+ colormodel=colormodel,
+ x0=x0,
+ y0=y0,
+ dx=dx,
+ dy=dy,
+ )
+ for index_tup in itertools.product(*iterables)
+ ]
layout = {}
if origin == "lower" or (dy is not None and dy < 0):
layout["yaxis"] = dict(autorange=True)
@@ -408,19 +510,44 @@ def imshow(
else:
raise ValueError(
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
- "An image of shape %s was provided" % str(img.shape)
+ "An image of shape %s was provided."
+ "Alternatively, 3- or 4-D single or multichannel datasets can be"
+ "visualized using the `facet_col` or/and `animation_frame` arguments."
+ % str(img.shape)
)
- layout_patch = dict()
+ # Now build figure
+ col_labels = []
+ if facet_col is not None:
+ slice_label = (
+ "facet_col" if labels.get("facet_col") is None else labels["facet_col"]
+ )
+ col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices]
+ fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
for attr_name in ["height", "width"]:
if args[attr_name]:
- layout_patch[attr_name] = args[attr_name]
+ layout[attr_name] = args[attr_name]
if args["title"]:
- layout_patch["title_text"] = args["title"]
+ layout["title_text"] = args["title"]
elif args["template"].layout.margin.t is None:
- layout_patch["margin"] = {"t": 60}
- fig = go.Figure(data=trace, layout=layout)
- fig.update_layout(layout_patch)
+ layout["margin"] = {"t": 60}
+
+ frame_list = []
+ for index, trace in enumerate(traces):
+ if (facet_col and index < nrows * ncols) or index == 0:
+ fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
+ if animation_frame is not None:
+ for i, index in zip(range(nslices_animation), animation_slices):
+ frame_list.append(
+ dict(
+ data=traces[nslices_facet * i : nslices_facet * (i + 1)],
+ layout=layout,
+ name=str(index),
+ )
+ )
+ if animation_frame:
+ fig.frames = frame_list
+ fig.update_layout(layout)
# Hover name, z or color
if binary_string and rescale_image and not np.all(img == img_rescaled):
# we rescaled the image, hence z is not displayed in hover since it does
@@ -449,5 +576,6 @@ def imshow(
fig.update_xaxes(title_text=labels["x"])
if labels["y"]:
fig.update_yaxes(title_text=labels["y"])
+ configure_animation_controls(args, go.Image, fig)
fig.update_layout(template=args["template"], overwrite=True)
return fig
diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py
index 313267aacb..912b4151ab 100644
--- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py
+++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py
@@ -172,14 +172,26 @@ def test_zmin_zmax_range_color_source():
assert fig1 == fig2
-def test_imshow_xarray():
+@pytest.mark.parametrize("binary_string", [False, True])
+def test_imshow_xarray(binary_string):
img = np.random.random((20, 30))
da = xr.DataArray(img, dims=["dim_rows", "dim_cols"])
- fig = px.imshow(da)
+ fig = px.imshow(da, binary_string=binary_string)
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_cols"
assert fig.layout.yaxis.title.text == "dim_rows"
- assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_cols"]))
+ if not binary_string:
+ assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_cols"]))
+
+
+def test_imshow_xarray_slicethrough():
+ img = np.random.random((8, 9, 10))
+ da = xr.DataArray(img, dims=["dim_0", "dim_1", "dim_2"])
+ fig = px.imshow(da, animation_frame="dim_0")
+ # Dimensions are used for axis labels and coordinates
+ assert fig.layout.xaxis.title.text == "dim_2"
+ assert fig.layout.yaxis.title.text == "dim_1"
+ assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))
def test_imshow_labels_and_ranges():
@@ -346,3 +358,50 @@ def test_imshow_hovertemplate(binary_string):
fig.data[0].hovertemplate
== "x: %{x}
y: %{y}
color: %{z}"
)
+
+
+@pytest.mark.parametrize("facet_col", [0, 1, 2, -1])
+@pytest.mark.parametrize("binary_string", [False, True])
+def test_facet_col(facet_col, binary_string):
+ img = np.random.randint(255, size=(10, 9, 8))
+ facet_col_wrap = 3
+ fig = px.imshow(
+ img,
+ facet_col=facet_col,
+ facet_col_wrap=facet_col_wrap,
+ binary_string=binary_string,
+ )
+ nslices = img.shape[facet_col]
+ ncols = int(facet_col_wrap)
+ nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
+ nmax = ncols * nrows
+ assert "yaxis%d" % nmax in fig.layout
+ assert "yaxis%d" % (nmax + 1) not in fig.layout
+ assert len(fig.data) == nslices
+
+
+@pytest.mark.parametrize("animation_frame", [0, 1, 2, -1])
+@pytest.mark.parametrize("binary_string", [False, True])
+def test_animation_frame_grayscale(animation_frame, binary_string):
+ img = np.random.randint(255, size=(10, 9, 8)).astype(np.uint8)
+ fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,)
+ nslices = img.shape[animation_frame]
+ assert len(fig.frames) == nslices
+
+
+@pytest.mark.parametrize("animation_frame", [0, 1, 2])
+@pytest.mark.parametrize("binary_string", [False, True])
+def test_animation_frame_rgb(animation_frame, binary_string):
+ img = np.random.randint(255, size=(10, 9, 8, 3)).astype(np.uint8)
+ fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,)
+ nslices = img.shape[animation_frame]
+ assert len(fig.frames) == nslices
+
+
+@pytest.mark.parametrize("binary_string", [False, True])
+def test_animation_and_facet(binary_string):
+ img = np.random.randint(255, size=(10, 9, 8, 7)).astype(np.uint8)
+ fig = px.imshow(img, animation_frame=0, facet_col=1, binary_string=binary_string)
+ nslices = img.shape[0]
+ assert len(fig.frames) == nslices
+ assert len(fig.data) == img.shape[1]