diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py index 4ff085156c..3ec382d0f3 100644 --- a/packages/python/plotly/plotly/express/__init__.py +++ b/packages/python/plotly/plotly/express/__init__.py @@ -41,6 +41,8 @@ density_heatmap, ) +from ._imshow import imshow + from ._core import ( # noqa: F401 set_mapbox_access_token, defaults, @@ -75,6 +77,7 @@ "strip", "histogram", "choropleth", + "imshow", "data", "colors", "set_mapbox_access_token", diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py new file mode 100644 index 0000000000..e41efbeaa0 --- /dev/null +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -0,0 +1,134 @@ +import plotly.graph_objs as go +import numpy as np # is it fine to depend on np here? + +_float_types = [] + +# Adapted from skimage.util.dtype +_integer_types = ( + np.byte, + np.ubyte, # 8 bits + np.short, + np.ushort, # 16 bits + np.intc, + np.uintc, # 16 or 32 or 64 bits + np.int_, + np.uint, # 32 or 64 bits + np.longlong, + np.ulonglong, +) # 64 bits +_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types} + + +def _vectorize_zvalue(z): + if z is None: + return z + elif np.isscalar(z): + return [z] * 3 + [1] + elif len(z) == 1: + return list(z) * 3 + [1] + elif len(z) == 3: + return list(z) + [1] + elif len(z) == 4: + return z + else: + raise ValueError( + "zmax can be a scalar, or an iterable of length 1, 3 or 4. " + "A value of %s was passed for zmax." % str(z) + ) + + +def _infer_zmax_from_type(img): + dt = img.dtype.type + rtol = 1.05 + if dt in _integer_types: + return _integer_ranges[dt][1] + else: + im_max = img[np.isfinite(img)].max() + if im_max <= 1 * rtol: + return 1 + elif im_max <= 255 * rtol: + return 255 + elif im_max <= 65535 * rtol: + return 65535 + else: + return 2 ** 32 + + +def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): + """ + Display an image, i.e. data on a 2D regular raster. + + Parameters + ---------- + + img: array-like image + The image data. Supported array shapes are + + - (M, N): an image with scalar data. The data is visualized + using a colormap. + - (M, N, 3): an image with RGB values. + - (M, N, 4): an image with RGBA values, i.e. including transparency. + + zmin, zmax : scalar or iterable, optional + zmin and zmax define the scalar range that the colormap covers. By default, + zmin and zmax correspond to the min and max values of the datatype for integer + datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For + a multichannel image of floats, the max of the image is computed and zmax is the + smallest power of 256 (1, 255, 65535) greater than this max value, + with a 5% tolerance. For a single-channel image, the max of the image is used. + + origin : str, 'upper' or 'lower' (default 'upper') + position of the [0, 0] pixel of the image array, in the upper left or lower left + corner. The convention 'upper' is typically used for matrices and images. + + colorscale : str + colormap used to map scalar data to colors (for a 2D image). This parameter is not used for + RGB or RGBA images. + + Returns + ------- + fig : graph_objects.Figure containing the displayed image + + See also + -------- + + plotly.graph_objects.Image : image trace + plotly.graph_objects.Heatmap : heatmap trace + + Notes + ----- + + In order to update and customize the returned figure, use + `go.Figure.update_traces` or `go.Figure.update_layout`. + """ + img = np.asanyarray(img) + # Cast bools to uint8 (also one byte) + if img.dtype == np.bool: + img = 255 * img.astype(np.uint8) + + # For 2d data, use Heatmap trace + if img.ndim == 2: + if colorscale is None: + colorscale = "gray" + trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale) + autorange = True if origin == "lower" else "reversed" + layout = dict( + xaxis=dict(scaleanchor="y", constrain="domain"), + yaxis=dict(autorange=autorange, constrain="domain"), + ) + # For 2D+RGB data, use Image trace + elif img.ndim == 3 and img.shape[-1] in [3, 4]: + if zmax is None and img.dtype is not np.uint8: + zmax = _infer_zmax_from_type(img) + zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax) + trace = go.Image(z=img, zmin=zmin, zmax=zmax) + layout = {} + if origin == "lower": + layout["yaxis"] = dict(autorange=True) + else: + raise ValueError( + "px.imshow only accepts 2D grayscale, RGB or RGBA images. " + "An image of shape %s was provided" % str(img.shape) + ) + fig = go.Figure(data=trace, layout=layout) + return fig diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py new file mode 100644 index 0000000000..8b6130b998 --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -0,0 +1,98 @@ +import plotly.express as px +import numpy as np +import pytest + +img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8) +img_gray = np.arange(100).reshape((10, 10)) + + +def test_rgb_uint8(): + fig = px.imshow(img_rgb) + assert fig.data[0]["zmax"] == (255, 255, 255, 1) + + +def test_vmax(): + for zmax in [ + 100, + [100], + (100,), + [100, 100, 100], + (100, 100, 100), + (100, 100, 100, 1), + ]: + fig = px.imshow(img_rgb, zmax=zmax) + assert fig.data[0]["zmax"] == (100, 100, 100, 1) + + +def test_automatic_zmax_from_dtype(): + dtypes_dict = { + np.uint8: 2 ** 8 - 1, + np.uint16: 2 ** 16 - 1, + np.float: 1, + np.bool: 255, + } + for key, val in dtypes_dict.items(): + img = np.array([0, 1], dtype=key) + img = np.dstack((img,) * 3) + fig = px.imshow(img) + assert fig.data[0]["zmax"] == (val, val, val, 1) + + +def test_origin(): + for img in [img_rgb, img_gray]: + fig = px.imshow(img, origin="lower") + assert fig.layout.yaxis.autorange == True + fig = px.imshow(img_rgb) + assert fig.layout.yaxis.autorange is None + fig = px.imshow(img_gray) + assert fig.layout.yaxis.autorange == "reversed" + + +def test_colorscale(): + fig = px.imshow(img_gray) + assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)") + fig = px.imshow(img_gray, colorscale="Viridis") + assert fig.data[0].colorscale[0] == (0.0, "#440154") + + +def test_wrong_dimensions(): + imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)] + for img in imgs: + with pytest.raises(ValueError) as err_msg: + fig = px.imshow(img) + + +def test_nan_inf_data(): + imgs = [np.ones((20, 20)), 255 * np.ones((20, 20), dtype=np.uint8)] + zmaxs = [1, 255] + for zmax, img in zip(zmaxs, imgs): + img[0] = 0 + img[10:12] = np.nan + # the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it + fig = px.imshow(np.dstack((img,) * 3)) + assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1) + + +def test_zmax_floats(): + # RGB + imgs = [ + np.ones((5, 5, 3)), + 1.02 * np.ones((5, 5, 3)), + 2 * np.ones((5, 5, 3)), + 1000 * np.ones((5, 5, 3)), + ] + zmaxs = [1, 1, 255, 65535] + for zmax, img in zip(zmaxs, imgs): + fig = px.imshow(img) + assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1) + # single-channel + imgs = [ + np.ones((5, 5)), + 1.02 * np.ones((5, 5)), + 2 * np.ones((5, 5)), + 1000 * np.ones((5, 5)), + ] + for zmax, img in zip(zmaxs, imgs): + fig = px.imshow(img) + print(fig.data[0]["zmax"], zmax) + assert fig.data[0]["zmax"] == None