-
-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Imshow #1855
Merged
Merged
Imshow #1855
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
1e3a6c6
first version of imshow
emmanuelle ca2b326
removed kwargs
emmanuelle 88f37f1
black
emmanuelle 6600b0b
adjust for rebase
emmanuelle 45383e5
removed showticks kw argument in imsho
emmanuelle 986b33b
changed zmax behaviour
emmanuelle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import plotly.graph_objs as go | ||
import numpy as np # is it fine to depend on np here? | ||
|
||
_float_types = [] | ||
|
||
# Adapted from skimage.util.dtype | ||
_integer_types = ( | ||
np.byte, | ||
np.ubyte, # 8 bits | ||
np.short, | ||
np.ushort, # 16 bits | ||
np.intc, | ||
np.uintc, # 16 or 32 or 64 bits | ||
np.int_, | ||
np.uint, # 32 or 64 bits | ||
np.longlong, | ||
np.ulonglong, | ||
) # 64 bits | ||
_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types} | ||
|
||
|
||
def _vectorize_zvalue(z): | ||
if z is None: | ||
return z | ||
elif np.isscalar(z): | ||
return [z] * 3 + [1] | ||
elif len(z) == 1: | ||
return list(z) * 3 + [1] | ||
elif len(z) == 3: | ||
return list(z) + [1] | ||
elif len(z) == 4: | ||
return z | ||
else: | ||
raise ValueError( | ||
"zmax can be a scalar, or an iterable of length 1, 3 or 4. " | ||
"A value of %s was passed for zmax." % str(z) | ||
) | ||
|
||
|
||
def _infer_zmax_from_type(img): | ||
dt = img.dtype.type | ||
rtol = 1.05 | ||
if dt in _integer_types: | ||
return _integer_ranges[dt][1] | ||
else: | ||
im_max = img[np.isfinite(img)].max() | ||
if im_max <= 1 * rtol: | ||
return 1 | ||
elif im_max <= 255 * rtol: | ||
return 255 | ||
elif im_max <= 65535 * rtol: | ||
return 65535 | ||
else: | ||
return 2 ** 32 | ||
|
||
|
||
def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): | ||
""" | ||
Display an image, i.e. data on a 2D regular raster. | ||
|
||
Parameters | ||
---------- | ||
|
||
img: array-like image | ||
The image data. Supported array shapes are | ||
|
||
- (M, N): an image with scalar data. The data is visualized | ||
using a colormap. | ||
- (M, N, 3): an image with RGB values. | ||
- (M, N, 4): an image with RGBA values, i.e. including transparency. | ||
|
||
zmin, zmax : scalar or iterable, optional | ||
zmin and zmax define the scalar range that the colormap covers. By default, | ||
zmin and zmax correspond to the min and max values of the datatype for integer | ||
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For | ||
a multichannel image of floats, the max of the image is computed and zmax is the | ||
smallest power of 256 (1, 255, 65535) greater than this max value, | ||
with a 5% tolerance. For a single-channel image, the max of the image is used. | ||
|
||
origin : str, 'upper' or 'lower' (default 'upper') | ||
position of the [0, 0] pixel of the image array, in the upper left or lower left | ||
corner. The convention 'upper' is typically used for matrices and images. | ||
|
||
colorscale : str | ||
colormap used to map scalar data to colors (for a 2D image). This parameter is not used for | ||
RGB or RGBA images. | ||
|
||
Returns | ||
------- | ||
fig : graph_objects.Figure containing the displayed image | ||
|
||
See also | ||
-------- | ||
|
||
plotly.graph_objects.Image : image trace | ||
plotly.graph_objects.Heatmap : heatmap trace | ||
|
||
Notes | ||
----- | ||
|
||
In order to update and customize the returned figure, use | ||
`go.Figure.update_traces` or `go.Figure.update_layout`. | ||
""" | ||
img = np.asanyarray(img) | ||
# Cast bools to uint8 (also one byte) | ||
if img.dtype == np.bool: | ||
img = 255 * img.astype(np.uint8) | ||
|
||
# For 2d data, use Heatmap trace | ||
if img.ndim == 2: | ||
if colorscale is None: | ||
colorscale = "gray" | ||
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale) | ||
autorange = True if origin == "lower" else "reversed" | ||
layout = dict( | ||
xaxis=dict(scaleanchor="y", constrain="domain"), | ||
nicolaskruchten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
yaxis=dict(autorange=autorange, constrain="domain"), | ||
) | ||
# For 2D+RGB data, use Image trace | ||
elif img.ndim == 3 and img.shape[-1] in [3, 4]: | ||
if zmax is None and img.dtype is not np.uint8: | ||
zmax = _infer_zmax_from_type(img) | ||
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax) | ||
trace = go.Image(z=img, zmin=zmin, zmax=zmax) | ||
layout = {} | ||
if origin == "lower": | ||
layout["yaxis"] = dict(autorange=True) | ||
else: | ||
raise ValueError( | ||
"px.imshow only accepts 2D grayscale, RGB or RGBA images. " | ||
"An image of shape %s was provided" % str(img.shape) | ||
) | ||
fig = go.Figure(data=trace, layout=layout) | ||
return fig |
98 changes: 98 additions & 0 deletions
98
packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import plotly.express as px | ||
import numpy as np | ||
import pytest | ||
|
||
img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8) | ||
img_gray = np.arange(100).reshape((10, 10)) | ||
|
||
|
||
def test_rgb_uint8(): | ||
fig = px.imshow(img_rgb) | ||
assert fig.data[0]["zmax"] == (255, 255, 255, 1) | ||
|
||
|
||
def test_vmax(): | ||
for zmax in [ | ||
100, | ||
[100], | ||
(100,), | ||
[100, 100, 100], | ||
(100, 100, 100), | ||
(100, 100, 100, 1), | ||
]: | ||
fig = px.imshow(img_rgb, zmax=zmax) | ||
assert fig.data[0]["zmax"] == (100, 100, 100, 1) | ||
|
||
|
||
def test_automatic_zmax_from_dtype(): | ||
dtypes_dict = { | ||
np.uint8: 2 ** 8 - 1, | ||
np.uint16: 2 ** 16 - 1, | ||
np.float: 1, | ||
np.bool: 255, | ||
} | ||
for key, val in dtypes_dict.items(): | ||
img = np.array([0, 1], dtype=key) | ||
img = np.dstack((img,) * 3) | ||
fig = px.imshow(img) | ||
assert fig.data[0]["zmax"] == (val, val, val, 1) | ||
|
||
|
||
def test_origin(): | ||
for img in [img_rgb, img_gray]: | ||
fig = px.imshow(img, origin="lower") | ||
assert fig.layout.yaxis.autorange == True | ||
fig = px.imshow(img_rgb) | ||
assert fig.layout.yaxis.autorange is None | ||
fig = px.imshow(img_gray) | ||
assert fig.layout.yaxis.autorange == "reversed" | ||
|
||
|
||
def test_colorscale(): | ||
fig = px.imshow(img_gray) | ||
assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)") | ||
fig = px.imshow(img_gray, colorscale="Viridis") | ||
assert fig.data[0].colorscale[0] == (0.0, "#440154") | ||
|
||
|
||
def test_wrong_dimensions(): | ||
imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)] | ||
for img in imgs: | ||
with pytest.raises(ValueError) as err_msg: | ||
fig = px.imshow(img) | ||
|
||
|
||
def test_nan_inf_data(): | ||
imgs = [np.ones((20, 20)), 255 * np.ones((20, 20), dtype=np.uint8)] | ||
zmaxs = [1, 255] | ||
for zmax, img in zip(zmaxs, imgs): | ||
img[0] = 0 | ||
img[10:12] = np.nan | ||
# the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it | ||
fig = px.imshow(np.dstack((img,) * 3)) | ||
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1) | ||
|
||
|
||
def test_zmax_floats(): | ||
# RGB | ||
imgs = [ | ||
np.ones((5, 5, 3)), | ||
1.02 * np.ones((5, 5, 3)), | ||
2 * np.ones((5, 5, 3)), | ||
1000 * np.ones((5, 5, 3)), | ||
] | ||
zmaxs = [1, 1, 255, 65535] | ||
for zmax, img in zip(zmaxs, imgs): | ||
fig = px.imshow(img) | ||
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1) | ||
# single-channel | ||
imgs = [ | ||
np.ones((5, 5)), | ||
1.02 * np.ones((5, 5)), | ||
2 * np.ones((5, 5)), | ||
1000 * np.ones((5, 5)), | ||
] | ||
for zmax, img in zip(zmaxs, imgs): | ||
fig = px.imshow(img) | ||
print(fig.data[0]["zmax"], zmax) | ||
assert fig.data[0]["zmax"] == None |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is fine, but we should have a little friendly error message if it's not installed, similar to the pandas one for px, no?
in fact, px depends on numpy but doesn't have a friendly error message... maybe we can do both in a separate PR.
How about: let's open a separate issue to discuss the
numpy
dependency and move forward with this as is :)