Skip to content

Commit

Permalink
BUG: Preserve original dtype if originally complex_int16 (#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 authored Jun 24, 2021
1 parent 00ff99c commit 199c051
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 42 deletions.
3 changes: 2 additions & 1 deletion rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ def open_rasterio(
unsigned = variables.pop_to(attrs, encoding, "_Unsigned") == "true"

if masked:
encoding["dtype"] = _rasterio_to_numpy_dtype(riods.dtypes)
encoding["dtype"] = str(_rasterio_to_numpy_dtype(riods.dtypes))

da_name = attrs.pop("NETCDF_VARNAME", default_name)
data = indexing.LazilyOuterIndexedArray(
Expand Down Expand Up @@ -950,4 +950,5 @@ def open_rasterio(
result.rio._manager = manager
# add file path to encoding
result.encoding["source"] = riods.name
result.encoding["rasterio_dtype"] = str(riods.dtypes[0])
return result
38 changes: 6 additions & 32 deletions rioxarray/raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""
import copy
import warnings
from distutils.version import LooseVersion
from typing import Iterable

Expand All @@ -31,7 +30,12 @@
OneDimensionalRaster,
RioXarrayError,
)
from rioxarray.raster_writer import FILL_VALUE_NAMES, UNWANTED_RIO_ATTRS, RasterioWriter
from rioxarray.raster_writer import (
FILL_VALUE_NAMES,
UNWANTED_RIO_ATTRS,
RasterioWriter,
_ensure_nodata_dtype,
)
from rioxarray.rioxarray import XRasterBase, _get_data_var_message, _make_coords


Expand Down Expand Up @@ -114,26 +118,6 @@ def _make_dst_affine(
return dst_affine, dst_width, dst_height


def _ensure_nodata_dtype(original_nodata, new_dtype):
"""
Convert the nodata to the new datatype and raise warning
if the value of the nodata value changed.
"""
# Complex-valued rasters can have real-valued nodata
if str(new_dtype).startswith("c"):
nodata = original_nodata
else:
original_nodata = float(original_nodata)
nodata = np.dtype(new_dtype).type(original_nodata)
if not np.isnan(nodata) and original_nodata != nodata:
warnings.warn(
f"The nodata value ({original_nodata}) has been automatically "
f"changed to ({nodata}) to match the dtype of the data."
)

return nodata


def _clip_from_disk(xds, geometries, all_touched, drop, invert):
"""
clip from disk if the file object is available
Expand Down Expand Up @@ -918,11 +902,6 @@ def to_raster(
if driver is None and LooseVersion(rasterio.__version__) < LooseVersion("1.2"):
driver = "GTiff"

dtype = (
self._obj.encoding.get("dtype", str(self._obj.dtype))
if dtype is None
else dtype
)
# get the output profile from the rasterio object
# if opened with xarray.open_rasterio()
try:
Expand Down Expand Up @@ -950,11 +929,6 @@ def to_raster(
rio_nodata = (
self.encoded_nodata if self.encoded_nodata is not None else self.nodata
)
if rio_nodata is not None:
# Ensure dtype of output data matches the expected dtype.
# This check is added here as the dtype of the data is
# converted right before writing.
rio_nodata = _ensure_nodata_dtype(rio_nodata, dtype)

return RasterioWriter(raster_path=raster_path).to_raster(
xarray_dataarray=self._obj,
Expand Down
78 changes: 74 additions & 4 deletions rioxarray/raster_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
- https://github.com/dymaxionlabs/dask-rasterio/blob/8dd7fdece7ad094a41908c0ae6b4fe6ca49cf5e1/dask_rasterio/write.py # noqa: E501
"""
import warnings

import numpy
import rasterio
from rasterio.windows import Window
from xarray.conventions import encode_cf_variable
Expand Down Expand Up @@ -90,6 +93,66 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
raster_handle.set_band_description(iii + 1, band_description)


def _ensure_nodata_dtype(original_nodata, new_dtype):
"""
Convert the nodata to the new datatype and raise warning
if the value of the nodata value changed.
"""
# Complex-valued rasters can have real-valued nodata
if str(new_dtype).startswith("c"):
nodata = original_nodata
else:
original_nodata = float(original_nodata)
nodata = numpy.dtype(new_dtype).type(original_nodata)
if not numpy.isnan(nodata) and original_nodata != nodata:
warnings.warn(
f"The nodata value ({original_nodata}) has been automatically "
f"changed to ({nodata}) to match the dtype of the data."
)

return nodata


def _get_dtypes(rasterio_dtype, encoded_rasterio_dtype, dataarray_dtype):
"""
Determines the rasterio dtype and numpy dtypes based on
the rasterio dtype and the encoded rasterio dtype.
Parameters
----------
rasterio_dtype: Union[str, numpy.dtype]
The rasterio dtype to write to.
encoded_rasterio_dtype: Union[str, numpy.dtype, None]
The value of the original rasterio dtype in the encoding.
dataarray_dtype: Union[str, numpy.dtype]
The value of the dtype of the data array.
Returns
-------
Tuple[Union[str, numpy.dtype], Union[str, numpy.dtype]]:
The rasterio dtype and numpy dtype.
"""
# SCENARIO 1: User wants to write to complex_int16
if rasterio_dtype == "complex_int16":
numpy_dtype = "complex64"
# SCENARIO 2: File originally in complext_int16 and dtype unchanged
elif (
rasterio_dtype is None
and encoded_rasterio_dtype == "complex_int16"
and str(dataarray_dtype) == "complex64"
):
numpy_dtype = "complex64"
rasterio_dtype = "complex_int16"
# SCENARIO 3: rasterio dtype not provided
elif rasterio_dtype is None:
numpy_dtype = dataarray_dtype
rasterio_dtype = dataarray_dtype
# SCENARIO 4: rasterio dtype and numpy dtype are the same
else:
numpy_dtype = rasterio_dtype
return rasterio_dtype, numpy_dtype


class RasterioWriter:
"""
Expand Down Expand Up @@ -159,10 +222,17 @@ def to_raster(self, xarray_dataarray, tags, windowed, lock, compute, **kwargs):
**kwargs
Keyword arguments to pass into writing the raster.
"""
if str(kwargs["dtype"]) == "complex_int16":
numpy_dtype = "complex64"
else:
numpy_dtype = kwargs["dtype"]
kwargs["dtype"], numpy_dtype = _get_dtypes(
kwargs["dtype"],
xarray_dataarray.encoding.get("rasterio_dtype"),
xarray_dataarray.encoding.get("dtype", str(xarray_dataarray.dtype)),
)

if kwargs["nodata"] is not None:
# Ensure dtype of output data matches the expected dtype.
# This check is added here as the dtype of the data is
# converted right before writing.
kwargs["nodata"] = _ensure_nodata_dtype(kwargs["nodata"], numpy_dtype)

with rasterio.open(self.raster_path, "w", **kwargs) as rds:
_write_metatata_to_raster(rds, xarray_dataarray, tags)
Expand Down
18 changes: 13 additions & 5 deletions test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def test_open_rasterio_mask_chunk_clip():
"_FillValue": 0.0,
"grid_mapping": "spatial_ref",
"dtype": "uint16",
"rasterio_dtype": "uint16",
}
attrs = dict(xdi.attrs)
assert_almost_equal(
Expand Down Expand Up @@ -317,6 +318,7 @@ def test_open_rasterio_mask_chunk_clip():
"_FillValue": 0.0,
"grid_mapping": "spatial_ref",
"dtype": "uint16",
"rasterio_dtype": "uint16",
}

# test dataset
Expand All @@ -331,6 +333,7 @@ def test_open_rasterio_mask_chunk_clip():
"_FillValue": 0.0,
"grid_mapping": "spatial_ref",
"dtype": "uint16",
"rasterio_dtype": "uint16",
}


Expand Down Expand Up @@ -942,6 +945,7 @@ def test_mask_and_scale(open_rasterio):
"missing_value": 32767,
"grid_mapping": "crs",
"dtype": "uint16",
"rasterio_dtype": "uint16",
}
attrs = rds.air_temperature.attrs
assert attrs == {
Expand Down Expand Up @@ -973,6 +977,7 @@ def test_no_mask_and_scale(open_rasterio):
"missing_value": 32767,
"grid_mapping": "crs",
"dtype": "uint16",
"rasterio_dtype": "uint16",
}
attrs = rds.air_temperature.attrs
assert attrs == {
Expand Down Expand Up @@ -1113,15 +1118,17 @@ def test_non_rectilinear__skip_parse_coordinates(open_rasterio):
rasterio.__version__ < "1.2.4",
reason="https://github.com/mapbox/rasterio/issues/2182",
)
def test_cint16_dtype(tmp_path):
@pytest.mark.parametrize("dtype", [None, "complex_int16"])
def test_cint16_dtype(dtype, tmp_path):
test_file = os.path.join(TEST_INPUT_DATA_DIR, "cint16.tif")
xds = rioxarray.open_rasterio(test_file)
assert xds.rio.shape == (100, 100)
assert xds.dtype == "complex64"
assert xds.encoding["rasterio_dtype"] == "complex_int16"

tmp_output = tmp_path / "tmp_cint16.tif"
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output), dtype="complex_int16")
xds.rio.to_raster(str(tmp_output), dtype=dtype)
with rasterio.open(str(tmp_output)) as riofh:
data = riofh.read()
assert "complex_int16" in riofh.dtypes
Expand Down Expand Up @@ -1152,7 +1159,8 @@ def test_cint16_dtype_nodata(tmp_path):
assert riofh.nodata is None


def test_cint16_dtype_masked(tmp_path):
@pytest.mark.parametrize("dtype", [None, "complex_int16"])
def test_cint16_dtype_masked(dtype, tmp_path):
test_file = os.path.join(TEST_INPUT_DATA_DIR, "cint16.tif")
xds = rioxarray.open_rasterio(test_file, masked=True)
assert xds.rio.shape == (100, 100)
Expand All @@ -1162,7 +1170,7 @@ def test_cint16_dtype_masked(tmp_path):

tmp_output = tmp_path / "tmp_cint16.tif"
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output), dtype="complex_int16")
xds.rio.to_raster(str(tmp_output), dtype=dtype)
with rasterio.open(str(tmp_output)) as riofh:
data = riofh.read()
assert "complex_int16" in riofh.dtypes
Expand All @@ -1176,7 +1184,7 @@ def test_cint16_promote_dtype(tmp_path):

tmp_output = tmp_path / "tmp_cfloat64.tif"
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output))
xds.rio.to_raster(str(tmp_output), dtype="complex64")
with rasterio.open(str(tmp_output)) as riofh:
data = riofh.read()
assert "complex64" in riofh.dtypes
Expand Down

0 comments on commit 199c051

Please sign in to comment.