From 1e3a6c6dc15a0c2565dc3cbf627289d4598fdf86 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 30 Oct 2019 16:57:59 -0400 Subject: [PATCH 1/6] first version of imshow --- .../python/plotly/plotly/express/_imshow.py | 127 ++++++++++++++++++ .../tests/test_core/test_px/test_imshow.py | 73 ++++++++++ 2 files changed, 200 insertions(+) create mode 100644 packages/python/plotly/plotly/express/_imshow.py create mode 100644 packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py new file mode 100644 index 0000000000..33a15caf20 --- /dev/null +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -0,0 +1,127 @@ +import plotly.graph_objs as go +import numpy as np # is it fine to depend on np here? + +_float_types = [] + +# Adapted from skimage.util.dtype +_integer_types = ( + np.byte, + np.ubyte, # 8 bits + np.short, + np.ushort, # 16 bits + np.intc, + np.uintc, # 16 or 32 or 64 bits + np.int_, + np.uint, # 32 or 64 bits + np.longlong, + np.ulonglong, +) # 64 bits +_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types} + + +def _vectorize_zvalue(z): + if z is None: + return z + elif np.isscalar(z): + return [z] * 3 + [1] + elif len(z) == 1: + return list(z) * 3 + [1] + elif len(z) == 3: + return list(z) + [1] + elif len(z) == 4: + return z + else: + raise ValueError( + "zmax can be a scalar, or an iterable of length 1, 3 or 4. " + "A value of %s was passed for zmax." % str(z) + ) + + +def _infer_zmax_from_type(img): + dt = img.dtype.type + if dt in _integer_types: + return _integer_ranges[dt][1] + else: + return img[np.isfinite(img)].max() + + +def imshow( + img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=True, **kwargs +): + """ + Display an image, i.e. data on a 2D regular raster. + + Parameters + ---------- + + img: array-like image + The image data. Supported array shapes are + + - (M, N): an image with scalar data. The data is visualized + using a colormap. + - (M, N, 3): an image with RGB values. + - (M, N, 4): an image with RGBA values, i.e. including transparency. + + zmin, zmax : scalar or iterable, optional + zmin and zmax define the scalar range that the colormap covers. By default, + zmin and zmax correspond to the min and max values of the datatype for integer + datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.), and + to the min and max values of the image for an image of floats. + + origin : str, 'upper' or 'lower' (default 'upper') + position of the [0, 0] pixel of the image array, in the upper left or lower left + corner. The convention 'upper' is typically used for matrices and images. + + colorscale : str + colormap used to map scalar data to colors (for a 2D image). This parameter is not used for + RGB or RGBA images. + + showticks : bool, default True + if False, no tick labels are shown for pixel indices. + + ** kwargs : additional arguments to be passed to the Heatmap (grayscale) or Image (RGB) trace. + + Returns + ------- + fig : graph_objects.Figure containing the displayed image + + See also + -------- + + graph_objects.Image : image trace + graph_objects.Heatmap : heatmap trace + """ + img = np.asanyarray(img) + # Cast bools to uint8 (also one byte) + if img.dtype == np.bool: + img = 255 * img.astype(np.uint8) + + # For 2d data, use Heatmap trace + if img.ndim == 2: + if colorscale is None: + colorscale = "gray" + trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale, **kwargs) + autorange = True if origin == "lower" else "reversed" + layout = dict( + xaxis=dict(scaleanchor="y", constrain="domain"), + yaxis=dict(autorange=autorange, constrain="domain"), + ) + # For 2D+RGB data, use Image trace + elif img.ndim == 3 and img.shape[-1] in [3, 4]: + if zmax is None and img.dtype is not np.uint8: + zmax = _infer_zmax_from_type(img) + zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax) + trace = go.Image(z=img, zmin=zmin, zmax=zmax, **kwargs) + layout = {} + if origin == "lower": + layout["yaxis"] = dict(autorange=True) + else: + raise ValueError( + "px.imshow only accepts 2D grayscale, RGB or RGBA images. " + "An image of shape %s was provided" % str(img.shape) + ) + fig = go.Figure(data=trace, layout=layout) + if not showticks: + fig.update_xaxes(showticklabels=False) + fig.update_yaxes(showticklabels=False) + return fig diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py new file mode 100644 index 0000000000..ac3c570dce --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -0,0 +1,73 @@ +import plotly.express as px +import numpy as np +import pytest + +img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8) +img_gray = np.arange(100).reshape((10, 10)) + + +def test_rgb_uint8(): + fig = px.imshow(img_rgb) + assert fig.data[0]["zmax"] == (255, 255, 255, 1) + + +def test_vmax(): + for zmax in [ + 100, + [100], + (100,), + [100, 100, 100], + (100, 100, 100), + (100, 100, 100, 1), + ]: + fig = px.imshow(img_rgb, zmax=zmax) + assert fig.data[0]["zmax"] == (100, 100, 100, 1) + + +def test_automatic_zmax_from_dtype(): + dtypes_dict = { + np.uint8: 2 ** 8 - 1, + np.uint16: 2 ** 16 - 1, + np.float: 1, + np.bool: 255, + } + for key, val in dtypes_dict.items(): + img = np.array([0, 1], dtype=key) + img = np.dstack((img,) * 3) + fig = px.imshow(img) + assert fig.data[0]["zmax"] == (val, val, val, 1) + + +def test_origin(): + for img in [img_rgb, img_gray]: + fig = px.imshow(img, origin="lower") + assert fig.layout.yaxis.autorange == True + fig = px.imshow(img_rgb) + assert fig.layout.yaxis.autorange is None + fig = px.imshow(img_gray) + assert fig.layout.yaxis.autorange == "reversed" + + +def test_colorscale(): + fig = px.imshow(img_gray) + assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)") + fig = px.imshow(img_gray, colorscale="Viridis") + assert fig.data[0].colorscale[0] == (0.0, "#440154") + + +def test_wrong_dimensions(): + imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)] + for img in imgs: + with pytest.raises(ValueError) as err_msg: + fig = px.imshow(img) + + +def test_nan_inf_data(): + imgs = [np.ones((20, 20)), 255 * np.ones((20, 20), dtype=np.uint8)] + zmaxs = [1, 255] + for zmax, img in zip(zmaxs, imgs): + img[0] = 0 + img[10:12] = np.nan + # the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it + fig = px.imshow(np.dstack((img,) * 3)) + assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1) From ca2b326d5c159207d9f3b701dd96f9ad9855d4fc Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 31 Oct 2019 15:24:51 -0400 Subject: [PATCH 2/6] removed kwargs --- .../python/plotly/plotly/express/_imshow.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 33a15caf20..738d49803f 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -46,8 +46,7 @@ def _infer_zmax_from_type(img): def imshow( - img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=True, **kwargs -): + img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=True): """ Display an image, i.e. data on a 2D regular raster. @@ -79,8 +78,6 @@ def imshow( showticks : bool, default True if False, no tick labels are shown for pixel indices. - ** kwargs : additional arguments to be passed to the Heatmap (grayscale) or Image (RGB) trace. - Returns ------- fig : graph_objects.Figure containing the displayed image @@ -88,8 +85,13 @@ def imshow( See also -------- - graph_objects.Image : image trace - graph_objects.Heatmap : heatmap trace + plotly.graph_objects.Image : image trace + plotly.graph_objects.Heatmap : heatmap trace + + Notes + ----- + + In order to update and customize the returned figure, use `go.Figure.update_traces` or `go.Figure.update_layout`. """ img = np.asanyarray(img) # Cast bools to uint8 (also one byte) @@ -100,7 +102,7 @@ def imshow( if img.ndim == 2: if colorscale is None: colorscale = "gray" - trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale, **kwargs) + trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale) autorange = True if origin == "lower" else "reversed" layout = dict( xaxis=dict(scaleanchor="y", constrain="domain"), @@ -111,7 +113,7 @@ def imshow( if zmax is None and img.dtype is not np.uint8: zmax = _infer_zmax_from_type(img) zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax) - trace = go.Image(z=img, zmin=zmin, zmax=zmax, **kwargs) + trace = go.Image(z=img, zmin=zmin, zmax=zmax) layout = {} if origin == "lower": layout["yaxis"] = dict(autorange=True) From 88f37f1393fb129ac6e85b899c0a9c9365e755d4 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 31 Oct 2019 15:29:16 -0400 Subject: [PATCH 3/6] black --- packages/python/plotly/plotly/express/_imshow.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 738d49803f..ab4d1622b6 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -45,8 +45,7 @@ def _infer_zmax_from_type(img): return img[np.isfinite(img)].max() -def imshow( - img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=True): +def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=True): """ Display an image, i.e. data on a 2D regular raster. From 6600b0b202b645e300c5e092a909612f1b080615 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 6 Nov 2019 09:29:32 -0500 Subject: [PATCH 4/6] adjust for rebase --- packages/python/plotly/plotly/express/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py index 4ff085156c..3ec382d0f3 100644 --- a/packages/python/plotly/plotly/express/__init__.py +++ b/packages/python/plotly/plotly/express/__init__.py @@ -41,6 +41,8 @@ density_heatmap, ) +from ._imshow import imshow + from ._core import ( # noqa: F401 set_mapbox_access_token, defaults, @@ -75,6 +77,7 @@ "strip", "histogram", "choropleth", + "imshow", "data", "colors", "set_mapbox_access_token", From 45383e576f17c2084fc7a45f83f77d02fcbac661 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 6 Nov 2019 11:08:16 -0500 Subject: [PATCH 5/6] removed showticks kw argument in imsho --- packages/python/plotly/plotly/express/_imshow.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index ab4d1622b6..513126d083 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -45,7 +45,7 @@ def _infer_zmax_from_type(img): return img[np.isfinite(img)].max() -def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=True): +def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): """ Display an image, i.e. data on a 2D regular raster. @@ -74,9 +74,6 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=Tr colormap used to map scalar data to colors (for a 2D image). This parameter is not used for RGB or RGBA images. - showticks : bool, default True - if False, no tick labels are shown for pixel indices. - Returns ------- fig : graph_objects.Figure containing the displayed image @@ -122,7 +119,4 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=Tr "An image of shape %s was provided" % str(img.shape) ) fig = go.Figure(data=trace, layout=layout) - if not showticks: - fig.update_xaxes(showticklabels=False) - fig.update_yaxes(showticklabels=False) return fig From 986b33b4a7220209034fd5cd005c58a4b3f1c734 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 7 Nov 2019 09:09:13 -0500 Subject: [PATCH 6/6] changed zmax behaviour --- .../python/plotly/plotly/express/_imshow.py | 20 ++++++++++++--- .../tests/test_core/test_px/test_imshow.py | 25 +++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 513126d083..e41efbeaa0 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -39,10 +39,19 @@ def _vectorize_zvalue(z): def _infer_zmax_from_type(img): dt = img.dtype.type + rtol = 1.05 if dt in _integer_types: return _integer_ranges[dt][1] else: - return img[np.isfinite(img)].max() + im_max = img[np.isfinite(img)].max() + if im_max <= 1 * rtol: + return 1 + elif im_max <= 255 * rtol: + return 255 + elif im_max <= 65535 * rtol: + return 65535 + else: + return 2 ** 32 def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): @@ -63,8 +72,10 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): zmin, zmax : scalar or iterable, optional zmin and zmax define the scalar range that the colormap covers. By default, zmin and zmax correspond to the min and max values of the datatype for integer - datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.), and - to the min and max values of the image for an image of floats. + datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For + a multichannel image of floats, the max of the image is computed and zmax is the + smallest power of 256 (1, 255, 65535) greater than this max value, + with a 5% tolerance. For a single-channel image, the max of the image is used. origin : str, 'upper' or 'lower' (default 'upper') position of the [0, 0] pixel of the image array, in the upper left or lower left @@ -87,7 +98,8 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): Notes ----- - In order to update and customize the returned figure, use `go.Figure.update_traces` or `go.Figure.update_layout`. + In order to update and customize the returned figure, use + `go.Figure.update_traces` or `go.Figure.update_layout`. """ img = np.asanyarray(img) # Cast bools to uint8 (also one byte) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index ac3c570dce..8b6130b998 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -71,3 +71,28 @@ def test_nan_inf_data(): # the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it fig = px.imshow(np.dstack((img,) * 3)) assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1) + + +def test_zmax_floats(): + # RGB + imgs = [ + np.ones((5, 5, 3)), + 1.02 * np.ones((5, 5, 3)), + 2 * np.ones((5, 5, 3)), + 1000 * np.ones((5, 5, 3)), + ] + zmaxs = [1, 1, 255, 65535] + for zmax, img in zip(zmaxs, imgs): + fig = px.imshow(img) + assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1) + # single-channel + imgs = [ + np.ones((5, 5)), + 1.02 * np.ones((5, 5)), + 2 * np.ones((5, 5)), + 1000 * np.ones((5, 5)), + ] + for zmax, img in zip(zmaxs, imgs): + fig = px.imshow(img) + print(fig.data[0]["zmax"], zmax) + assert fig.data[0]["zmax"] == None