Skip to content

Commit

Permalink
Allow passing RGB xarray.DataArray images into grdimage (#2590)
Browse files Browse the repository at this point in the history
Saving 3-band xarray.DataArray images to a temporary GeoTIFF
so that they can be plotted with grdimage.

* Refactor to use tempfile_from_image

Putting the temporary GeoTIFF creation logic in a dedicated
tempfile_from_image helper function, so that it can be reused
by other GMT modules besides grdimage. Also ensure that an
ImportError is raised when the `.rio` attribute cannot be found
when rioxarray is not installed.

* Let tilemap use tempfile_from_image func in virtualfile_from_data

Refactor Figure.tilemap to use the same tempfile_from_image
function that generates a temporary GeoTIFF file from the 3-band
xarray.DataArray images.

* Update docstring of grdimage with upstream GMT 6.4.0

Various updates from upstream GMT at
GenericMappingTools/gmt#6258,
GenericMappingTools/gmt@9069967,
GenericMappingTools/gmt#7260.

* Raise RuntimeWarning when input image dtype is not uint8

Plotting a non-uint8 dtype xarray.DataArray works in grdimage,
but the results will likely be incorrect. Warning the user about
the incorrect dtype, and suggest recasting to uint8 with 0-255 range,
e.g. using a histogram equalization function like
skimage.exposure.equalize_hist.

---------

Co-authored-by: Dongdong Tian <seisman.info@gmail.com>
  • Loading branch information
weiji14 and seisman authored Aug 8, 2023
1 parent a671fb0 commit cf022a2
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 33 deletions.
22 changes: 19 additions & 3 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ctypes as ctp
import pathlib
import sys
import warnings
from contextlib import contextmanager, nullcontext

import numpy as np
Expand All @@ -26,7 +27,12 @@
GMTInvalidInput,
GMTVersionError,
)
from pygmt.helpers import data_kind, fmt_docstring, tempfile_from_geojson
from pygmt.helpers import (
data_kind,
fmt_docstring,
tempfile_from_geojson,
tempfile_from_image,
)

FAMILIES = [
"GMT_IS_DATASET", # Entity is a data table
Expand Down Expand Up @@ -1540,7 +1546,7 @@ def virtualfile_from_data(
if check_kind:
valid_kinds = ("file", "arg") if required_data is False else ("file",)
if check_kind == "raster":
valid_kinds += ("grid",)
valid_kinds += ("grid", "image")
elif check_kind == "vector":
valid_kinds += ("matrix", "vectors", "geojson")
if kind not in valid_kinds:
Expand All @@ -1554,6 +1560,7 @@ def virtualfile_from_data(
"arg": nullcontext,
"geojson": tempfile_from_geojson,
"grid": self.virtualfile_from_grid,
"image": tempfile_from_image,
# Note: virtualfile_from_matrix is not used because a matrix can be
# converted to vectors instead, and using vectors allows for better
# handling of string type inputs (e.g. for datetime data types)
Expand All @@ -1562,7 +1569,16 @@ def virtualfile_from_data(
}[kind]

# Ensure the data is an iterable (Python list or tuple)
if kind in ("geojson", "grid", "file", "arg"):
if kind in ("geojson", "grid", "image", "file", "arg"):
if kind == "image" and data.dtype != "uint8":
msg = (
f"Input image has dtype: {data.dtype} which is unsupported, "
"and may result in an incorrect output. Please recast image "
"to a uint8 dtype and/or scale to 0-255 range, e.g. "
"using a histogram equalization function like "
"skimage.exposure.equalize_hist."
)
warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2)
_data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),)
elif kind == "vectors":
_data = [np.atleast_1d(x), np.atleast_1d(y)]
Expand Down
7 changes: 6 additions & 1 deletion pygmt/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
kwargs_to_strings,
use_alias,
)
from pygmt.helpers.tempfile import GMTTempFile, tempfile_from_geojson, unique_name
from pygmt.helpers.tempfile import (
GMTTempFile,
tempfile_from_geojson,
tempfile_from_image,
unique_name,
)
from pygmt.helpers.utils import (
args_in_kwargs,
build_arg_string,
Expand Down
31 changes: 31 additions & 0 deletions pygmt/helpers/tempfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,34 @@ def tempfile_from_geojson(geojson):
geoseries.to_file(**ogrgmt_kwargs)

yield tmpfile.name


@contextmanager
def tempfile_from_image(image):
"""
Saves a 3-band :class:`xarray.DataArray` to a temporary GeoTIFF file via
rioxarray.
Parameters
----------
image : xarray.DataArray
An xarray.DataArray with three dimensions, having a shape like
(3, Y, X).
Yields
------
tmpfilename : str
A temporary GeoTIFF file holding the image data. E.g. '1a2b3c4d5.tif'.
"""
with GMTTempFile(suffix=".tif") as tmpfile:
os.remove(tmpfile.name) # ensure file is deleted first
try:
image.rio.to_raster(raster_path=tmpfile.name)
except AttributeError as e: # object has no attribute 'rio'
raise ImportError(
"Package `rioxarray` is required to be installed to use this function. "
"Please use `python -m pip install rioxarray` or "
"`mamba install -c conda-forge rioxarray` "
"to install the package."
) from e
yield tmpfile.name
8 changes: 5 additions & 3 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
Returns
-------
kind : str
One of ``'arg'``, ``'file'``, ``'grid'``, ``'geojson'``, ``'matrix'``,
or ``'vectors'``.
One of ``'arg'``, ``'file'``, ``'grid'``, ``image``, ``'geojson'``,
``'matrix'``, or ``'vectors'``.
Examples
--------
Expand All @@ -166,14 +166,16 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
'arg'
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
'grid'
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
'image'
"""
# determine the data kind
if isinstance(data, (str, pathlib.PurePath)):
kind = "file"
elif isinstance(data, (bool, int, float)) or (data is None and not required_data):
kind = "arg"
elif isinstance(data, xr.DataArray):
kind = "grid"
kind = "image" if len(data.dims) == 3 else "grid"
elif hasattr(data, "__geo_interface__"):
# geo-like Python object that implements ``__geo_interface__``
# (geopandas.GeoDataFrame or shapely.geometry)
Expand Down
27 changes: 12 additions & 15 deletions pygmt/src/grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def grdimage(self, grid, **kwargs):
instructions to derive intensities from the input data grid. Values outside
this range will be clipped. Such intensity files can be created from the
grid using :func:`pygmt.grdgradient` and, optionally, modified by
:gmt-docs:`grdmath.html` or :class:`pygmt.grdhisteq`. If GMT is built
with GDAL support, ``grid`` can be an image file (geo-referenced or not).
In this case the image can optionally be illuminated with the file
provided via the ``shading`` parameter. Here, if image has no coordinates
then those of the intensity file will be used.
:gmt-docs:`grdmath.html` or :class:`pygmt.grdhisteq`. Alternatively, pass
*image* which can be an image file (geo-referenced or not). In this case
the image can optionally be illuminated with the file provided via the
``shading`` parameter. Here, if image has no coordinates then those of the
intensity file will be used.
When using map projections, the grid is first resampled on a new
rectangular grid with the same dimensions. Higher resolution images can
Expand Down Expand Up @@ -74,10 +74,7 @@ def grdimage(self, grid, **kwargs):
:gmt-docs:`grdimage.html#grid-file-formats`).
img_out : str
*out_img*\[=\ *driver*].
Save an image in a raster format instead of PostScript. Use
extension .ppm for a Portable Pixel Map format which is the only
raster format GMT can natively write. For GMT installations
configured with GDAL support there are more choices: Append
Save an image in a raster format instead of PostScript. Append
*out_img* to select the image file name and extension. If the
extension is one of .bmp, .gif, .jpg, .png, or .tif then no driver
information is required. For other output formats you must append
Expand Down Expand Up @@ -131,8 +128,8 @@ def grdimage(self, grid, **kwargs):
:func:`pygmt.grdgradient` separately first. If we should derive
intensities from another file than grid, specify the file with
suitable modifiers [Default is no illumination]. **Note**: If the
input data is an *image* then an *intensfile* or constant *intensity*
must be provided.
input data represent an *image* then an *intensfile* or constant
*intensity* must be provided.
{projection}
monochrome : bool
Force conversion to monochrome image using the (television) YIQ
Expand All @@ -144,10 +141,9 @@ def grdimage(self, grid, **kwargs):
[**+z**\ *value*][*color*]
Make grid nodes with z = NaN transparent, using the color-masking
feature in PostScript Level 3 (the PS device must support PS Level
3). If the input is a grid, use **+z** with a *value* to select
another grid value than NaN. If the input is instead an image,
append an alternate *color* to select another pixel value to be
transparent [Default is ``"black"``].
3). If the input is a grid, use **+z** to select another grid value
than NaN. If input is instead an image, append an alternate *color* to
select another pixel value to be transparent [Default is ``"black"``].
{region}
{verbose}
{panel}
Expand All @@ -171,6 +167,7 @@ def grdimage(self, grid, **kwargs):
>>> fig.show()
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access

with Session() as lib:
with lib.virtualfile_from_data(
check_kind="raster", data=grid
Expand Down
16 changes: 5 additions & 11 deletions pygmt/src/tilemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@
"""
from pygmt.clib import Session
from pygmt.datasets.tile_map import load_tile_map
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias

try:
import rioxarray
Expand Down Expand Up @@ -148,9 +142,9 @@ def tilemap(
if kwargs.get("N") in [None, False]:
kwargs["R"] = "/".join(str(coordinate) for coordinate in region)

with GMTTempFile(suffix=".tif") as tmpfile:
raster.rio.to_raster(raster_path=tmpfile.name)
with Session() as lib:
with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=raster)
with file_context as infile:
lib.call_module(
module="grdimage", args=build_arg_string(kwargs, infile=tmpfile.name)
module="grdimage", args=build_arg_string(kwargs, infile=infile)
)
4 changes: 4 additions & 0 deletions pygmt/tests/baseline/test_grdimage_image.png.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
outs:
- md5: 2e919645d5af956ec4f8aa054a86a70a
size: 110214
path: test_grdimage_image.png
79 changes: 79 additions & 0 deletions pygmt/tests/test_grdimage_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Test Figure.grdimage on 3-band RGB images.
"""
import numpy as np
import pandas as pd
import pytest
import xarray as xr
from pygmt import Figure, which

rasterio = pytest.importorskip("rasterio")
rioxarray = pytest.importorskip("rioxarray")


@pytest.fixture(scope="module", name="xr_image")
def fixture_xr_image():
"""
Load the image data from Blue Marble as an xarray.DataArray with shape
{"band": 3, "y": 180, "x": 360}.
"""
geotiff = which(fname="@earth_day_01d_p", download="c")
with rioxarray.open_rasterio(filename=geotiff) as rda:
if len(rda.band) == 1:
with rasterio.open(fp=geotiff) as src:
df_colormap = pd.DataFrame.from_dict(
data=src.colormap(1), orient="index"
)
array = src.read()

red = np.vectorize(df_colormap[0].get)(array)
green = np.vectorize(df_colormap[1].get)(array)
blue = np.vectorize(df_colormap[2].get)(array)
# alpha = np.vectorize(df_colormap[3].get)(array)

rda.data = red
da_red = rda.astype(dtype=np.uint8).copy()
rda.data = green
da_green = rda.astype(dtype=np.uint8).copy()
rda.data = blue
da_blue = rda.astype(dtype=np.uint8).copy()

xr_image = xr.concat(objs=[da_red, da_green, da_blue], dim="band")
assert xr_image.sizes == {"band": 3, "y": 180, "x": 360}
return xr_image


@pytest.mark.mpl_image_compare
def test_grdimage_image():
"""
Plot a 3-band RGB image using file input.
"""
fig = Figure()
fig.grdimage(grid="@earth_day_01d")
return fig


@pytest.mark.mpl_image_compare(filename="test_grdimage_image.png")
def test_grdimage_image_dataarray(xr_image):
"""
Plot a 3-band RGB image using xarray.DataArray input.
"""
fig = Figure()
fig.grdimage(grid=xr_image)
return fig


@pytest.mark.parametrize(
"dtype",
["int8", "uint16", "int16", "uint32", "int32", "float32", "float64"],
)
def test_grdimage_image_dataarray_unsupported_dtype(dtype, xr_image):
"""
Plot a 3-band RGB image using xarray.DataArray input, with an unsupported
data type.
"""
fig = Figure()
image = xr_image.astype(dtype=dtype)
with pytest.warns(expected_warning=RuntimeWarning) as record:
fig.grdimage(grid=image)
assert len(record) == 1

0 comments on commit cf022a2

Please sign in to comment.