Skip to content

Commit

Permalink
Imshow (#1855)
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuelle authored Nov 7, 2019
1 parent 643f58e commit 96967d7
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 0 deletions.
3 changes: 3 additions & 0 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
density_heatmap,
)

from ._imshow import imshow

from ._core import ( # noqa: F401
set_mapbox_access_token,
defaults,
Expand Down Expand Up @@ -75,6 +77,7 @@
"strip",
"histogram",
"choropleth",
"imshow",
"data",
"colors",
"set_mapbox_access_token",
Expand Down
134 changes: 134 additions & 0 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import plotly.graph_objs as go
import numpy as np # is it fine to depend on np here?

_float_types = []

# Adapted from skimage.util.dtype
_integer_types = (
np.byte,
np.ubyte, # 8 bits
np.short,
np.ushort, # 16 bits
np.intc,
np.uintc, # 16 or 32 or 64 bits
np.int_,
np.uint, # 32 or 64 bits
np.longlong,
np.ulonglong,
) # 64 bits
_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types}


def _vectorize_zvalue(z):
if z is None:
return z
elif np.isscalar(z):
return [z] * 3 + [1]
elif len(z) == 1:
return list(z) * 3 + [1]
elif len(z) == 3:
return list(z) + [1]
elif len(z) == 4:
return z
else:
raise ValueError(
"zmax can be a scalar, or an iterable of length 1, 3 or 4. "
"A value of %s was passed for zmax." % str(z)
)


def _infer_zmax_from_type(img):
dt = img.dtype.type
rtol = 1.05
if dt in _integer_types:
return _integer_ranges[dt][1]
else:
im_max = img[np.isfinite(img)].max()
if im_max <= 1 * rtol:
return 1
elif im_max <= 255 * rtol:
return 255
elif im_max <= 65535 * rtol:
return 65535
else:
return 2 ** 32


def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
"""
Display an image, i.e. data on a 2D regular raster.
Parameters
----------
img: array-like image
The image data. Supported array shapes are
- (M, N): an image with scalar data. The data is visualized
using a colormap.
- (M, N, 3): an image with RGB values.
- (M, N, 4): an image with RGBA values, i.e. including transparency.
zmin, zmax : scalar or iterable, optional
zmin and zmax define the scalar range that the colormap covers. By default,
zmin and zmax correspond to the min and max values of the datatype for integer
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
a multichannel image of floats, the max of the image is computed and zmax is the
smallest power of 256 (1, 255, 65535) greater than this max value,
with a 5% tolerance. For a single-channel image, the max of the image is used.
origin : str, 'upper' or 'lower' (default 'upper')
position of the [0, 0] pixel of the image array, in the upper left or lower left
corner. The convention 'upper' is typically used for matrices and images.
colorscale : str
colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
RGB or RGBA images.
Returns
-------
fig : graph_objects.Figure containing the displayed image
See also
--------
plotly.graph_objects.Image : image trace
plotly.graph_objects.Heatmap : heatmap trace
Notes
-----
In order to update and customize the returned figure, use
`go.Figure.update_traces` or `go.Figure.update_layout`.
"""
img = np.asanyarray(img)
# Cast bools to uint8 (also one byte)
if img.dtype == np.bool:
img = 255 * img.astype(np.uint8)

# For 2d data, use Heatmap trace
if img.ndim == 2:
if colorscale is None:
colorscale = "gray"
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale)
autorange = True if origin == "lower" else "reversed"
layout = dict(
xaxis=dict(scaleanchor="y", constrain="domain"),
yaxis=dict(autorange=autorange, constrain="domain"),
)
# For 2D+RGB data, use Image trace
elif img.ndim == 3 and img.shape[-1] in [3, 4]:
if zmax is None and img.dtype is not np.uint8:
zmax = _infer_zmax_from_type(img)
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax)
trace = go.Image(z=img, zmin=zmin, zmax=zmax)
layout = {}
if origin == "lower":
layout["yaxis"] = dict(autorange=True)
else:
raise ValueError(
"px.imshow only accepts 2D grayscale, RGB or RGBA images. "
"An image of shape %s was provided" % str(img.shape)
)
fig = go.Figure(data=trace, layout=layout)
return fig
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import plotly.express as px
import numpy as np
import pytest

img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
img_gray = np.arange(100).reshape((10, 10))


def test_rgb_uint8():
fig = px.imshow(img_rgb)
assert fig.data[0]["zmax"] == (255, 255, 255, 1)


def test_vmax():
for zmax in [
100,
[100],
(100,),
[100, 100, 100],
(100, 100, 100),
(100, 100, 100, 1),
]:
fig = px.imshow(img_rgb, zmax=zmax)
assert fig.data[0]["zmax"] == (100, 100, 100, 1)


def test_automatic_zmax_from_dtype():
dtypes_dict = {
np.uint8: 2 ** 8 - 1,
np.uint16: 2 ** 16 - 1,
np.float: 1,
np.bool: 255,
}
for key, val in dtypes_dict.items():
img = np.array([0, 1], dtype=key)
img = np.dstack((img,) * 3)
fig = px.imshow(img)
assert fig.data[0]["zmax"] == (val, val, val, 1)


def test_origin():
for img in [img_rgb, img_gray]:
fig = px.imshow(img, origin="lower")
assert fig.layout.yaxis.autorange == True
fig = px.imshow(img_rgb)
assert fig.layout.yaxis.autorange is None
fig = px.imshow(img_gray)
assert fig.layout.yaxis.autorange == "reversed"


def test_colorscale():
fig = px.imshow(img_gray)
assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)")
fig = px.imshow(img_gray, colorscale="Viridis")
assert fig.data[0].colorscale[0] == (0.0, "#440154")


def test_wrong_dimensions():
imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)]
for img in imgs:
with pytest.raises(ValueError) as err_msg:
fig = px.imshow(img)


def test_nan_inf_data():
imgs = [np.ones((20, 20)), 255 * np.ones((20, 20), dtype=np.uint8)]
zmaxs = [1, 255]
for zmax, img in zip(zmaxs, imgs):
img[0] = 0
img[10:12] = np.nan
# the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it
fig = px.imshow(np.dstack((img,) * 3))
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1)


def test_zmax_floats():
# RGB
imgs = [
np.ones((5, 5, 3)),
1.02 * np.ones((5, 5, 3)),
2 * np.ones((5, 5, 3)),
1000 * np.ones((5, 5, 3)),
]
zmaxs = [1, 1, 255, 65535]
for zmax, img in zip(zmaxs, imgs):
fig = px.imshow(img)
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1)
# single-channel
imgs = [
np.ones((5, 5)),
1.02 * np.ones((5, 5)),
2 * np.ones((5, 5)),
1000 * np.ones((5, 5)),
]
for zmax, img in zip(zmaxs, imgs):
fig = px.imshow(img)
print(fig.data[0]["zmax"], zmax)
assert fig.data[0]["zmax"] == None

0 comments on commit 96967d7

Please sign in to comment.