Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Imshow #1855

Merged
merged 6 commits into from
Nov 7, 2019
Merged

Imshow #1855

Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
density_heatmap,
)

from ._imshow import imshow

from ._core import ( # noqa: F401
set_mapbox_access_token,
defaults,
Expand Down Expand Up @@ -75,6 +77,7 @@
"strip",
"histogram",
"choropleth",
"imshow",
"data",
"colors",
"set_mapbox_access_token",
Expand Down
128 changes: 128 additions & 0 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import plotly.graph_objs as go
import numpy as np # is it fine to depend on np here?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is fine, but we should have a little friendly error message if it's not installed, similar to the pandas one for px, no?

in fact, px depends on numpy but doesn't have a friendly error message... maybe we can do both in a separate PR.

How about: let's open a separate issue to discuss the numpy dependency and move forward with this as is :)


_float_types = []

# Adapted from skimage.util.dtype
_integer_types = (
np.byte,
np.ubyte, # 8 bits
np.short,
np.ushort, # 16 bits
np.intc,
np.uintc, # 16 or 32 or 64 bits
np.int_,
np.uint, # 32 or 64 bits
np.longlong,
np.ulonglong,
) # 64 bits
_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types}


def _vectorize_zvalue(z):
if z is None:
return z
elif np.isscalar(z):
return [z] * 3 + [1]
elif len(z) == 1:
return list(z) * 3 + [1]
elif len(z) == 3:
return list(z) + [1]
elif len(z) == 4:
return z
else:
raise ValueError(
"zmax can be a scalar, or an iterable of length 1, 3 or 4. "
"A value of %s was passed for zmax." % str(z)
)


def _infer_zmax_from_type(img):
dt = img.dtype.type
if dt in _integer_types:
return _integer_ranges[dt][1]
else:
return img[np.isfinite(img)].max()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per our conversation, I think this should return "the smallest power of 255 which is greater than the max".

My rationale here is that if you have a pipeline that's intended to produce output between 0-X, and you use this to display the output of multiple inputs, it would be nice for them to have the same bounds, instead of having the bounds vary. And it feels like the most likely values of X are 1 and 255, and then possibly thereafter some powers thereof, if the data is 16-bit or 32-bit or whatnot :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one thing which bothers me though is the following case: suppose an image is in the [0-1] range but some numerical computation (filter etc.) changes the max to 1 + some small value. Then the zmax will be 255 for a max of say 1.05, and it will look really bad.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add a multiplicative tolerance factor like 10%

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be OK with this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about having a keyword argument colorrange (or a better name consistent maybe with other parts of the library) which could be data (take min/max) or dtype (inference based on what we've discussed)? With this I think most people would be happy... Or does this complicate too much the API ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My feeling is that that complicates too much



def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=True):
"""
Display an image, i.e. data on a 2D regular raster.

Parameters
----------

img: array-like image
The image data. Supported array shapes are

- (M, N): an image with scalar data. The data is visualized
using a colormap.
- (M, N, 3): an image with RGB values.
- (M, N, 4): an image with RGBA values, i.e. including transparency.

zmin, zmax : scalar or iterable, optional
zmin and zmax define the scalar range that the colormap covers. By default,
zmin and zmax correspond to the min and max values of the datatype for integer
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.), and
to the min and max values of the image for an image of floats.

origin : str, 'upper' or 'lower' (default 'upper')
position of the [0, 0] pixel of the image array, in the upper left or lower left
corner. The convention 'upper' is typically used for matrices and images.

colorscale : str
colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
RGB or RGBA images.

showticks : bool, default True
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved
if False, no tick labels are shown for pixel indices.

Returns
-------
fig : graph_objects.Figure containing the displayed image

See also
--------

plotly.graph_objects.Image : image trace
plotly.graph_objects.Heatmap : heatmap trace

Notes
-----

In order to update and customize the returned figure, use `go.Figure.update_traces` or `go.Figure.update_layout`.
"""
img = np.asanyarray(img)
# Cast bools to uint8 (also one byte)
if img.dtype == np.bool:
img = 255 * img.astype(np.uint8)

# For 2d data, use Heatmap trace
if img.ndim == 2:
if colorscale is None:
colorscale = "gray"
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale)
autorange = True if origin == "lower" else "reversed"
layout = dict(
xaxis=dict(scaleanchor="y", constrain="domain"),
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved
yaxis=dict(autorange=autorange, constrain="domain"),
)
# For 2D+RGB data, use Image trace
elif img.ndim == 3 and img.shape[-1] in [3, 4]:
if zmax is None and img.dtype is not np.uint8:
zmax = _infer_zmax_from_type(img)
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax)
trace = go.Image(z=img, zmin=zmin, zmax=zmax)
layout = {}
if origin == "lower":
layout["yaxis"] = dict(autorange=True)
else:
raise ValueError(
"px.imshow only accepts 2D grayscale, RGB or RGBA images. "
"An image of shape %s was provided" % str(img.shape)
)
fig = go.Figure(data=trace, layout=layout)
if not showticks:
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
return fig
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import plotly.express as px
import numpy as np
import pytest

img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
img_gray = np.arange(100).reshape((10, 10))


def test_rgb_uint8():
fig = px.imshow(img_rgb)
assert fig.data[0]["zmax"] == (255, 255, 255, 1)


def test_vmax():
for zmax in [
100,
[100],
(100,),
[100, 100, 100],
(100, 100, 100),
(100, 100, 100, 1),
]:
fig = px.imshow(img_rgb, zmax=zmax)
assert fig.data[0]["zmax"] == (100, 100, 100, 1)


def test_automatic_zmax_from_dtype():
dtypes_dict = {
np.uint8: 2 ** 8 - 1,
np.uint16: 2 ** 16 - 1,
np.float: 1,
np.bool: 255,
}
for key, val in dtypes_dict.items():
img = np.array([0, 1], dtype=key)
img = np.dstack((img,) * 3)
fig = px.imshow(img)
assert fig.data[0]["zmax"] == (val, val, val, 1)


def test_origin():
for img in [img_rgb, img_gray]:
fig = px.imshow(img, origin="lower")
assert fig.layout.yaxis.autorange == True
fig = px.imshow(img_rgb)
assert fig.layout.yaxis.autorange is None
fig = px.imshow(img_gray)
assert fig.layout.yaxis.autorange == "reversed"


def test_colorscale():
fig = px.imshow(img_gray)
assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)")
fig = px.imshow(img_gray, colorscale="Viridis")
assert fig.data[0].colorscale[0] == (0.0, "#440154")


def test_wrong_dimensions():
imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)]
for img in imgs:
with pytest.raises(ValueError) as err_msg:
fig = px.imshow(img)


def test_nan_inf_data():
imgs = [np.ones((20, 20)), 255 * np.ones((20, 20), dtype=np.uint8)]
zmaxs = [1, 255]
for zmax, img in zip(zmaxs, imgs):
img[0] = 0
img[10:12] = np.nan
# the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it
fig = px.imshow(np.dstack((img,) * 3))
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1)