From df576357b790d1bf2cf8d5f6917fae267b87cb21 Mon Sep 17 00:00:00 2001 From: grougrou Date: Sun, 20 Jun 2021 08:45:12 +0200 Subject: [PATCH 1/3] fix dtype complex for rasterio backend --- xarray/backends/rasterio_.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 49a5a9ec7ae..d12b4446a1c 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -39,7 +39,10 @@ def __init__(self, manager, lock, vrt_params=None): dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError("All bands should have the same dtype") - self._dtype = np.dtype(dtypes[0]) + if dtypes[0]!='complex_int16': + self._dtype = np.dtype(dtypes[0]) + else: + self._dtype = np.complex @property def dtype(self): From 6ca31cbd631fec43f28968a99ef360d49da8ed51 Mon Sep 17 00:00:00 2001 From: Antoine Grouazel Date: Thu, 24 Jun 2021 13:35:48 +0200 Subject: [PATCH 2/3] Update xarray/backends/rasterio_.py Co-authored-by: keewis --- xarray/backends/rasterio_.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index d12b4446a1c..a1a72a24299 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -39,10 +39,10 @@ def __init__(self, manager, lock, vrt_params=None): dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError("All bands should have the same dtype") - if dtypes[0]!='complex_int16': - self._dtype = np.dtype(dtypes[0]) - else: + if dtypes[0] == "complex_int16": self._dtype = np.complex + else: + self._dtype = np.dtype(dtypes[0]) @property def dtype(self): From 674d9b69fc7cd99c63c8d10764fd900442ec3be2 Mon Sep 17 00:00:00 2001 From: grouazel Date: Wed, 25 Aug 2021 16:30:45 +0200 Subject: [PATCH 3/3] add backend test --- xarray/backends/rasterio_.py | 6 +++--- xarray/tests/test_backends.py | 17 +++++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index a1a72a24299..3da40688a47 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -39,10 +39,10 @@ def __init__(self, manager, lock, vrt_params=None): dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError("All bands should have the same dtype") - if dtypes[0] == "complex_int16": - self._dtype = np.complex - else: + if dtypes[0] != 'complex_int16' : self._dtype = np.dtype(dtypes[0]) + else : + self._dtype = complex @property def dtype(self): diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 5079cd390f1..675026eff75 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4137,6 +4137,7 @@ def create_tmp_geotiff( crs=default_value, open_kwargs=None, additional_attrs=None, + specific_dtype=None ): if transform_args is default_value: transform_args = [5000, 80000, 1000, 2000.0] @@ -4163,7 +4164,11 @@ def create_tmp_geotiff( else: data_shape = nz, ny, nx write_kwargs = {} - data = np.arange(nz * ny * nx, dtype=rasterio.float32).reshape(*data_shape) + if specific_dtype is None: + specific_dtype = rasterio.float32 + data = np.arange(nz * ny * nx, dtype=rasterio.float32).reshape(*data_shape) + else: + data = np.arange(nz * ny * nx,dtype=rasterio.float32).reshape(*data_shape) if transform is None: transform = from_origin(*transform_args) if additional_attrs is None: @@ -4180,7 +4185,7 @@ def create_tmp_geotiff( count=nz, crs=crs, transform=transform, - dtype=rasterio.float32, + dtype=specific_dtype, **open_kwargs, ) as s: for attr, val in additional_attrs.items(): @@ -4697,6 +4702,14 @@ def test_rasterio_vrt_network(self): assert actual_res == expected_res assert expected_val == actual_val + def test_rasterio_complex_dtype( self ): + import rasterio + with create_tmp_geotiff(specific_dtype='complex_int16', + ) as (tmp_file, _): + with rasterio.open(tmp_file) as riobj: + assert riobj.dtypes[0]=='complex_int16' + with xr.open_rasterio(tmp_file) as rioda: + assert rioda.dtype==complex class TestEncodingInvalid: def test_extract_nc4_variable_encoding(self):