diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 9bd3cb0ff..48d607efd 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -9,10 +9,13 @@ | `langchain.ChatCohereDataset` | A dataset for loading a ChatCohere langchain model. | `kedro_datasets_experimental.langchain` | | `langchain.OpenAIEmbeddingsDataset` | A dataset for loading a OpenAIEmbeddings langchain model. | `kedro_datasets_experimental.langchain` | | `langchain.ChatOpenAIDataset` | A dataset for loading a ChatOpenAI langchain model. | `kedro_datasets_experimental.langchain` | +| `rioxarray.GeoTIFFDataset` | A dataset for loading and saving geotiff raster data | `kedro_datasets_experimental.rioxarray` | | `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` | + * `netcdf.NetCDFDataset` moved from `kedro_datasets` to `kedro_datasets_experimental`. * Added the following new core datasets: + | Type | Description | Location | |-------------------------------------|-----------------------------------------------------------|-----------------------------------------| | `dask.CSVDataset` | A dataset for loading a CSV files using `dask` | `kedro_datasets.dask` | @@ -22,6 +25,9 @@ ## Community contributions Many thanks to the following Kedroids for contributing PRs to this release: +* [Ian Whalen](https://github.com/ianwhale) +* [Charles Guan](https://github.com/charlesbmi) +* [Thomas Gölles](https://github.com/tgoelles) * [Lukas Innig](https://github.com/derluke) * [Michael Sexton](https://github.com/michaelsexton) diff --git a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst index fbae09589..34cb3caf8 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst @@ -16,3 +16,4 @@ kedro_datasets_experimental kedro_datasets_experimental.langchain.ChatOpenAIDataset kedro_datasets_experimental.langchain.OpenAIEmbeddingsDataset kedro_datasets_experimental.netcdf.NetCDFDataset + kedro_datasets_experimental.rioxarray.GeoTIFFDataset diff --git a/kedro-datasets/kedro_datasets_experimental/rioxarray/__init__.py b/kedro-datasets/kedro_datasets_experimental/rioxarray/__init__.py new file mode 100644 index 000000000..b1f52ce01 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/rioxarray/__init__.py @@ -0,0 +1,13 @@ +"""``AbstractDataset`` implementation to load/save data from/to a geospatial raster files.""" +from __future__ import annotations + +from typing import Any + +import lazy_loader as lazy + +# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 +GeoTIFFDataset: Any + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, submod_attrs={"geotiff_dataset": ["GeoTIFFDataset"]} +) diff --git a/kedro-datasets/kedro_datasets_experimental/rioxarray/geotiff_dataset.py b/kedro-datasets/kedro_datasets_experimental/rioxarray/geotiff_dataset.py new file mode 100644 index 000000000..b69dea574 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/rioxarray/geotiff_dataset.py @@ -0,0 +1,209 @@ +"""GeoTIFFDataset loads geospatial raster data and saves it to a local geoiff file. The +underlying functionality is supported by rioxarray and xarray. A read rasterdata file +returns a xarray.DataArray object. +""" +import logging +from copy import deepcopy +from pathlib import PurePosixPath +from typing import Any + +import fsspec +import rasterio +import rioxarray as rxr +import xarray +from kedro.io import AbstractVersionedDataset, DatasetError +from kedro.io.core import Version, get_filepath_str, get_protocol_and_path +from rasterio.crs import CRS +from rasterio.transform import from_bounds + +logger = logging.getLogger(__name__) + +SUPPORTED_DIMS = [("band", "x", "y"), ("x", "y")] +DEFAULT_NO_DATA_VALUE = -9999 +SUPPORTED_FILE_FORMATS = [".tif", ".tiff"] + + +class GeoTIFFDataset(AbstractVersionedDataset[xarray.DataArray, xarray.DataArray]): + """``GeoTIFFDataset`` loads and saves rasterdata files and reads them as xarray + DataArrays. The underlying functionality is supported by rioxarray, rasterio and xarray. + + Reading and writing of single and multiband GeoTIFFs data is supported. There are sanity checks to ensure that a coordinate reference system (CRS) is present. + Supported dimensions are ("band", "x", "y") and ("x", "y") and xarray.DataArray with other dimension can not be saved to a GeoTIFF file. + Have a look at netcdf if this is what you need. + + + .. code-block:: yaml + + sentinal_data: + type: rioxarray.GeoTIFFDataset + filepath: sentinal_data.tif + + Example usage for the + `Python API `_: + + .. code-block:: pycon + + >>> from kedro_datasets.rioxarray import GeoTIFFDataset + >>> import xarray as xr + >>> import numpy as np + >>> + >>> data = xr.DataArray( + ... np.random.randn(2, 3, 2), + ... dims=("band", "y", "x"), + ... coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]} + ... ) + >>> data_crs = data.rio.write_crs("epsg:4326") + >>> data_spatial_dims = data_crs.rio.set_spatial_dims("x", "y") + >>> dataset = GeoTIFFDataset(filepath="test.tif") + >>> dataset.save(data_spatial_dims) + >>> reloaded = dataset.load() + >>> xr.testing.assert_allclose(data_spatial_dims, reloaded, rtol=1e-5) + + """ + + DEFAULT_LOAD_ARGS: dict[str, Any] = {} + DEFAULT_SAVE_ARGS: dict[str, Any] = {} + + def __init__( # noqa: PLR0913 + self, + *, + filepath: str, + load_args: dict[str, Any] | None = None, + save_args: dict[str, Any] | None = None, + version: Version | None = None, + metadata: dict[str, Any] | None = None, + ): + """Creates a new instance of ``GeoTIFFDataset`` pointing to a concrete + geospatial raster data file. + + + Args: + filepath: Filepath in POSIX format to a rasterdata file. + The prefix should be any protocol supported by ``fsspec``. + load_args: rioxarray options for loading rasterdata files. + Here you can find all available arguments: + https://corteva.github.io/rioxarray/html/rioxarray.html#rioxarray-open-rasterio + All defaults are preserved. + save_args: options for rioxarray for data without the band dimension and rasterio otherwhise. + version: If specified, should be an instance of + ``kedro.io.core.Version``. If its ``load`` attribute is + None, the latest version will be loaded. If its ``save`` + attribute is None, save version will be autogenerated. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + """ + protocol, path = get_protocol_and_path(filepath, version) + self._protocol = protocol + self._fs = fsspec.filesystem(self._protocol) + self.metadata = metadata + + super().__init__( + filepath=PurePosixPath(path), + version=version, + exists_function=self._fs.exists, + glob_function=self._fs.glob, + ) + + # Handle default load and save arguments + self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) + if load_args is not None: + self._load_args.update(load_args) + self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) + if save_args is not None: + self._save_args.update(save_args) + + def _describe(self) -> dict[str, Any]: + return { + "filepath": self._filepath, + "protocol": self._protocol, + "load_args": self._load_args, + "save_args": self._save_args, + "version": self._version, + } + + def _load(self) -> xarray.DataArray: + load_path = self._get_load_path().as_posix() + with rasterio.open(load_path) as data: + tags = data.tags() + data = rxr.open_rasterio(load_path, **self._load_args) + data.attrs.update(tags) + self._sanity_check(data) + logger.info(f"found coordinate rerence system {data.rio.crs}") + return data + + def _save(self, data: xarray.DataArray) -> None: + self._sanity_check(data) + save_path = get_filepath_str(self._get_save_path(), self._protocol) + if not save_path.endswith(tuple(SUPPORTED_FILE_FORMATS)): + raise ValueError( + f"Unsupported file format. Supported formats are: {SUPPORTED_FILE_FORMATS}" + ) + if "band" in data.dims: + self._save_multiband(data, save_path) + else: + data.rio.to_raster(save_path, **self._save_args) + self._fs.invalidate_cache(save_path) + + def _exists(self) -> bool: + try: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + except DatasetError: + return False + + return self._fs.exists(load_path) + + def _release(self) -> None: + super()._release() + self._invalidate_cache() + + def _invalidate_cache(self) -> None: + """Invalidate underlying filesystem caches.""" + filepath = get_filepath_str(self._filepath, self._protocol) + self._fs.invalidate_cache(filepath) + + def _save_multiband(self, data: xarray.DataArray, save_path: str): + """Saving multiband raster data to a geotiff file.""" + bands_data = [data.sel(band=band) for band in data.band.values] + transform = from_bounds( + west=data.x.min(), + south=data.y.min(), + east=data.x.max(), + north=data.y.max(), + width=data[0].shape[1], + height=data[0].shape[0], + ) + + nodata_value = ( + data.rio.nodata if data.rio.nodata is not None else DEFAULT_NO_DATA_VALUE + ) + crs = data.rio.crs + + meta = { + "driver": "GTiff", + "height": bands_data[0].shape[0], + "width": bands_data[0].shape[1], + "count": len(bands_data), + "dtype": str(bands_data[0].dtype), + "crs": crs, + "transform": transform, + "nodata": nodata_value, + } + with rasterio.open(save_path, "w", **meta) as dst: + for idx, band in enumerate(bands_data, start=1): + dst.write(band.data, idx, **self._save_args) + + def _sanity_check(self, data: xarray.DataArray) -> None: + """Perform sanity checks on the data to ensure it meets the requirements.""" + if not isinstance(data, xarray.DataArray): + raise NotImplementedError( + "Currently only supporting xarray.DataArray while saving raster data." + ) + + if not isinstance(data.rio.crs, CRS): + raise ValueError("Dataset lacks a coordinate reference system.") + + if all(set(data.dims) != set(dims) for dims in SUPPORTED_DIMS): + raise ValueError( + f"Data has unsupported dimensions: {data.dims}. Supported dimensions are: {SUPPORTED_DIMS}" + ) diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/__init__.py b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/cog.tif b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/cog.tif new file mode 100644 index 000000000..e2bc24a1c Binary files /dev/null and b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/cog.tif differ diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_geotiff_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_geotiff_dataset.py new file mode 100644 index 000000000..7f217eee6 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_geotiff_dataset.py @@ -0,0 +1,181 @@ +from pathlib import Path + +import numpy as np +import pytest +import rasterio +import xarray as xr +from kedro.io import DatasetError +from rasterio.crs import CRS + +from kedro_datasets_experimental.rioxarray.geotiff_dataset import GeoTIFFDataset + + +@pytest.fixture +def cog_file_path() -> str: + cog_file_path = Path(__file__).parent / "cog.tif" + return cog_file_path.as_posix() + +@pytest.fixture +def multi1_file_path() -> str: + path = Path(__file__).parent / "test_multi1.tif" + return path.as_posix() + +@pytest.fixture +def multi2_file_path() -> str: + path = Path(__file__).parent / "test_multi2.tif" + return path.as_posix() + +@pytest.fixture +def synthetic_xarray(): + """Create a synthetic xarray.DataArray with CRS information.""" + data = xr.DataArray( + np.random.rand(100, 100), + dims=("y", "x"), + coords={"x": np.linspace(0, 100, 100), "y": np.linspace(0, 100, 100)} + ) + data.rio.write_crs("epsg:4326", inplace=True) + return data + +@pytest.fixture +def synthetic_xarray_multiband(): + """Create a synthetic xarray.DataArray with CRS information.""" + data = xr.DataArray( + np.random.rand(10, 100, 100), + dims=("band", "y", "x"), + coords={"x": np.linspace(0, 100, 100), "y": np.linspace(0, 100, 100)} + ) + data.rio.write_crs("epsg:4326", inplace=True) + return data + +@pytest.fixture +def synthetic_xarray_many_vars_no_band(): + """Create a synthetic xarray.DataArray with CRS information.""" + data = xr.DataArray( + np.random.rand(2,3,4, 100, 100), + dims=("var1","var2","var3","y", "x"), + coords={"x": np.linspace(0, 100, 100), "y": np.linspace(0, 100, 100)} + ) + data.rio.write_crs("epsg:4326", inplace=True) + return data + + +@pytest.fixture +def cog_geotiff_dataset(cog_file_path, save_args) -> GeoTIFFDataset: + return GeoTIFFDataset(filepath=cog_file_path, save_args=save_args) + + +def test_load_cog_geotiff(cog_geotiff_dataset): + """Test loading cloud optimised geotiff reloading the data set.""" + loaded_xr = cog_geotiff_dataset.load() + assert isinstance(loaded_xr.rio.crs, CRS) + assert isinstance(loaded_xr, xr.DataArray) + assert loaded_xr.shape == (1, 500, 500) + assert loaded_xr.dims == ("band", "y", "x") + +def test_load_save_cog(tmp_path,cog_file_path): + """Test loading a multiband raster file.""" + dataset = GeoTIFFDataset(filepath=cog_file_path) + loaded_xr = dataset.load() + band1_data = loaded_xr.sel(band=1) + target_file = tmp_path / "tmp22.tif" + dataset_to = GeoTIFFDataset(filepath=str(target_file)) + dataset_to.save(loaded_xr) + reloaded_xr = dataset_to.load() + assert target_file.exists() + assert isinstance(loaded_xr.rio.crs, CRS) + assert isinstance(loaded_xr, xr.DataArray) + assert len(loaded_xr.band) == 1 + assert loaded_xr.dims == ("band", "y", "x") + assert loaded_xr.shape == (1, 500, 500) + assert np.isclose(band1_data.values.std(), 4688.72624578268) + assert (loaded_xr.values == reloaded_xr.values).all() + + + +def test_load_save_multi1(tmp_path,multi1_file_path): + """Test loading a multiband raster file.""" + dataset = GeoTIFFDataset(filepath=multi1_file_path) + dataset_to = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + loaded_xr = dataset.load() + band1_data = loaded_xr.sel(band=1) + assert isinstance(loaded_xr.rio.crs, CRS) + assert isinstance(loaded_xr, xr.DataArray) + BAND_COUNT = 2 + assert len(loaded_xr.band) == BAND_COUNT + assert loaded_xr.shape == (BAND_COUNT, 5, 5) + assert loaded_xr.dims == ("band", "y", "x") + assert np.isclose(band1_data.values.std(), 0.015918046) + dataset_to.save(loaded_xr) + reloaded_xr = dataset_to.load() + assert (loaded_xr.values == reloaded_xr.values).all() + +def test_load_geotiff_with_tags(tmp_path, synthetic_xarray): + filepath = tmp_path / "test_with_tags.tif" + tags = {"TAG_KEY": "TAG_VALUE", "ANOTHER_TAG": "ANOTHER_VALUE"} + with rasterio.open( + filepath, "w", driver="GTiff", height=100, width=100, count=1, dtype=str(synthetic_xarray.dtype), + crs="EPSG:4326" + ) as dst: + dst.write(synthetic_xarray.values, 1) + dst.update_tags(**tags) + + dataset = GeoTIFFDataset(filepath=str(filepath)) + loaded_xr = dataset.load() + + assert loaded_xr.attrs["TAG_KEY"] == "TAG_VALUE" + assert loaded_xr.attrs["ANOTHER_TAG"] == "ANOTHER_VALUE" + + assert isinstance(loaded_xr, xr.DataArray) + assert isinstance(loaded_xr.rio.crs, CRS) + assert loaded_xr.shape == (1, 100, 100) + +def test_load_no_crs(multi2_file_path): + """Test loading a multiband raster file.""" + dataset = GeoTIFFDataset(filepath=multi2_file_path) + with pytest.raises(DatasetError): + dataset.load() + +def test_load_not_tif(): + """Test loading a multiband raster file.""" + dataset = GeoTIFFDataset(filepath="whatever.nc") + with pytest.raises(DatasetError): + dataset.load() + + +def test_exists(tmp_path, synthetic_xarray): + """Test `exists` method invocation for both existing and + nonexistent data set.""" + dataset = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + assert not dataset.exists() + dataset.save(synthetic_xarray) + assert dataset.exists() + +@pytest.mark.parametrize("xarray_fixture", [ + "synthetic_xarray_multiband", + "synthetic_xarray", +]) +def test_save_and_load_geotiff(tmp_path, request, xarray_fixture): + """Test saving and reloading the data set.""" + xarray_data = request.getfixturevalue(xarray_fixture) + dataset = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + dataset.save(xarray_data) + assert dataset.exists() + reloaded_xr = dataset.load() + assert isinstance(reloaded_xr, xr.DataArray) + assert isinstance(reloaded_xr.rio.crs, CRS) + assert reloaded_xr.dims == ("band", "y", "x") + assert (xarray_data.values == reloaded_xr.values).all() + +def test_save_and_load_geotiff_no_band(tmp_path, synthetic_xarray_many_vars_no_band): + """this test should fail because the data array has no band dimension""" + dataset = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + with pytest.raises(DatasetError): + dataset.save(synthetic_xarray_many_vars_no_band) + +def test_load_missing_file(tmp_path): + """Check the error when trying to load missing file.""" + dataset = GeoTIFFDataset(filepath=str(tmp_path / "tmp.tif")) + assert not dataset._exists(), "File unexpectedly exists" + pattern = r"Failed while loading data from data set GeoTIFFDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + dataset.load() diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi1.tif b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi1.tif new file mode 100644 index 000000000..bfc0c6a2c Binary files /dev/null and b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi1.tif differ diff --git a/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi2.tif b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi2.tif new file mode 100644 index 000000000..6dfbedb6a Binary files /dev/null and b/kedro-datasets/kedro_datasets_experimental/tests/rioxarray/test_multi2.tif differ diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 01a70d27f..ac0a3a17d 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -169,9 +169,13 @@ langchain-openaiembeddingsdataset = ["langchain-openai~=0.1.7"] langchain-chatanthropicdataset = ["langchain-anthropic~=0.1.13", "langchain-community~=0.2.0"] langchain-chatcoheredataset = ["langchain-cohere~=0.1.5", "langchain-community~=0.2.0"] langchain = ["kedro-datasets[langchain-chatopenaidataset,langchain-openaiembeddingsdataset,langchain-chatanthropicdataset,langchain-chatcoheredataset ]"] + netcdf-netcdfdataset = ["h5netcdf>=1.2.0","netcdf4>=1.6.4","xarray>=2023.1.0"] netcdf = ["kedro-datasets[netcdf-netcdfdataset]"] +rioxarray-geotiffdataset = ["rioxarray>=0.15.0"] +rioxarray = ["kedro-datasets[rioxarray-geotiffdataset]"] + # Docs requirements docs = [ "kedro-sphinx-theme==2024.4.0", @@ -274,6 +278,7 @@ experimental = [ "h5netcdf>=1.2.0", "netcdf4>=1.6.4", "xarray>=2023.1.0", + "rioxarray", ] # All requirements