diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index a90176166..d007db7d3 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,9 +1,17 @@ # Upcoming Release ## Major features and improvements +* Added the following new **experimental** datasets: + +| Type | Description | Location | +|-------------------------------------|-----------------------------------------------------------|-----------------------------------------| +| `pytorch.PyTorchDataset` | A dataset for securely saving and loading PyTorch models | `kedro_datasets_experimental.pytorch` | ## Bug fixes and other changes ## Breaking Changes ## Community contributions +Many thanks to the following Kedroids for contributing PRs to this release: +* [Brandon Meek](https://github.com/bpmeek) + # Release 4.1.0 ## Major features and improvements @@ -27,13 +35,13 @@ | `langchain.OpenAIEmbeddingsDataset` | A dataset for loading a OpenAIEmbeddings langchain model. | `kedro_datasets_experimental.langchain` | | `langchain.ChatOpenAIDataset` | A dataset for loading a ChatOpenAI langchain model. | `kedro_datasets_experimental.langchain` | | `rioxarray.GeoTIFFDataset` | A dataset for loading and saving geotiff raster data | `kedro_datasets_experimental.rioxarray` | -| `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` | +| `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` | * Added the following new core datasets: | Type | Description | Location | |-------------------------------------|-----------------------------------------------------------|-----------------------------------------| -| `dask.CSVDataset` | A dataset for loading a CSV files using `dask` | `kedro_datasets.dask` | +| `dask.CSVDataset` | A dataset for loading a CSV files using `dask` | `kedro_datasets.dask` | * Extended preview feature to `yaml.YAMLDataset`. 
diff --git a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
index c6e443564..0eb76c739 100644
--- a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
+++ b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
@@ -16,4 +16,5 @@ kedro_datasets_experimental
     langchain.ChatOpenAIDataset
     langchain.OpenAIEmbeddingsDataset
     netcdf.NetCDFDataset
+    pytorch.PyTorchDataset
     rioxarray.GeoTIFFDataset
diff --git a/kedro-datasets/docs/source/conf.py b/kedro-datasets/docs/source/conf.py
index ef46c8401..70c6be3ae 100644
--- a/kedro-datasets/docs/source/conf.py
+++ b/kedro-datasets/docs/source/conf.py
@@ -139,6 +139,7 @@
         "langchain_cohere.chat_models.ChatCohere",
         "xarray.core.dataset.Dataset",
         "xarray.core.dataarray.DataArray",
+        "torch.nn.modules.module.Module",
     ),
     "py:data": (
         "typing.Any",
diff --git a/kedro-datasets/kedro_datasets_experimental/pytorch/__init__.py b/kedro-datasets/kedro_datasets_experimental/pytorch/__init__.py
new file mode 100644
index 000000000..1165a0311
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/pytorch/__init__.py
@@ -0,0 +1,11 @@
+"""``AbstractDataset`` implementation to load/save torch models using PyTorch's built-in methods."""
+
+from typing import Any
+
+import lazy_loader as lazy
+
+PyTorchDataset: Any
+
+__getattr__, __dir__, __all__ = lazy.attach(
+    __name__, submod_attrs={"pytorch_dataset": ["PyTorchDataset"]}
+)
diff --git a/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py b/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py
new file mode 100644
index 000000000..914fdb6b7
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+from copy import deepcopy
+from pathlib import PurePosixPath
+from typing import Any
+
+import fsspec
+import torch
+from kedro.io.core import (
+    AbstractVersionedDataset,
+    DatasetError,
+    Version,
+    get_filepath_str,
+    get_protocol_and_path,
+)
+
+
+class PyTorchDataset(AbstractVersionedDataset[Any, Any]):
+    """``PyTorchDataset`` loads and saves a PyTorch model's ``state_dict``
+    using PyTorch's recommended zipfile serialization protocol, which avoids
+    pickling the whole model object.
+
+    Example usage for the
+    `YAML API <https://docs.kedro.org/en/stable/data/data_catalog_yaml_examples.html>`_:
+
+    .. code-block:: yaml
+
+        model:
+          type: pytorch.PyTorchDataset
+          filepath: data/06_models/model.pt
+
+    Example usage for the
+    `Python API <https://docs.kedro.org/en/stable/data/advanced_data_catalog_usage.html>`_:
+
+    .. code-block:: pycon
+
+        >>> from kedro_datasets_experimental.pytorch import PyTorchDataset
+        >>> import torch
+        >>>
+        >>> model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
+        >>> dataset = PyTorchDataset(filepath=tmp_path / "model.pt")
+        >>> dataset.save(model)
+        >>> reloaded = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
+        >>> reloaded.load_state_dict(dataset.load())
+    """
+
+    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: dict[str, Any] = {}
+
+    def __init__(  # noqa: PLR0913
+        self,
+        *,
+        filepath: str,
+        load_args: dict[str, Any] | None = None,
+        save_args: dict[str, Any] | None = None,
+        version: Version | None = None,
+        credentials: dict[str, Any] | None = None,
+        fs_args: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ):
+        """Creates a new instance of ``PyTorchDataset``.
+
+        Args:
+            filepath: Filepath in POSIX format to the model file, optionally
+                prefixed with a protocol (e.g. ``s3://``) handled by fsspec.
+            load_args: Additional keyword arguments passed on to ``torch.load``.
+            save_args: Additional keyword arguments passed on to ``torch.save``.
+            version: If specified, should be an instance of
+                ``kedro.io.core.Version``.
+            credentials: Credentials required to access the underlying filesystem.
+            fs_args: Extra arguments to pass to the fsspec filesystem; the
+                ``open_args_load`` / ``open_args_save`` keys are forwarded to
+                the filesystem's ``open`` call on load and save respectively.
+            metadata: Any arbitrary user metadata; not used by Kedro itself.
+        """
+        _fs_args = deepcopy(fs_args) or {}
+        _fs_open_args_load = _fs_args.pop("open_args_load", {})
+        _fs_open_args_save = _fs_args.pop("open_args_save", {})
+        _credentials = deepcopy(credentials) or {}
+
+        protocol, path = get_protocol_and_path(filepath, version)
+        if protocol == "file":
+            # Create missing parent directories transparently for local saves.
+            _fs_args.setdefault("auto_mkdir", True)
+
+        self._protocol = protocol
+        self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)
+
+        self.metadata = metadata
+
+        super().__init__(
+            filepath=PurePosixPath(path),
+            version=version,
+            exists_function=self._fs.exists,
+            glob_function=self._fs.glob,
+        )
+
+        # Handle default load and save arguments
+        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        # ``mode`` defaults mirror torch's binary serialization format; users
+        # may override via fs_args["open_args_load"/"open_args_save"].
+        _fs_open_args_load.setdefault("mode", "rb")
+        _fs_open_args_save.setdefault("mode", "wb")
+        self._fs_open_args_load = _fs_open_args_load
+        self._fs_open_args_save = _fs_open_args_save
+
+    def _describe(self) -> dict[str, Any]:
+        """Returns a dict describing the dataset's configuration."""
+        return {
+            "filepath": self._filepath,
+            "protocol": self._protocol,
+            "load_args": self._load_args,
+            "save_args": self._save_args,
+            "version": self._version,
+        }
+
+    def _load(self) -> Any:
+        load_path = get_filepath_str(self._get_load_path(), self._protocol)
+        # Open through fsspec so that non-local protocols (s3://, gcs://, ...)
+        # work; ``load_args`` (not the fsspec open args) go to ``torch.load``.
+        with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
+            return torch.load(fs_file, **self._load_args)
+
+    def _save(self, data: torch.nn.Module) -> None:
+        save_path = get_filepath_str(self._get_save_path(), self._protocol)
+        # Persist only the state_dict, never the pickled module object.
+        with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
+            torch.save(data.state_dict(), fs_file, **self._save_args)
+
+        self._invalidate_cache()
+
+    def _exists(self) -> bool:
+        try:
+            load_path = get_filepath_str(self._get_load_path(), self._protocol)
+        except DatasetError:
+            # No versioned path resolvable yet -> dataset does not exist.
+            return False
+
+        return self._fs.exists(load_path)
+
+    def _release(self) -> None:
+        super()._release()
+        self._invalidate_cache()
+
+    def _invalidate_cache(self) -> None:
+        """Invalidate underlying filesystem caches."""
+        filepath = get_filepath_str(self._filepath, self._protocol)
+        self._fs.invalidate_cache(filepath)
diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml
index 5c7b19733..e20054573 100644
--- a/kedro-datasets/pyproject.toml
+++ b/kedro-datasets/pyproject.toml
@@ -173,6 +173,9 @@ langchain = ["kedro-datasets[langchain-chatopenaidataset,langchain-openaiembeddi
 netcdf-netcdfdataset = ["h5netcdf>=1.2.0","netcdf4>=1.6.4","xarray>=2023.1.0"]
 netcdf = ["kedro-datasets[netcdf-netcdfdataset]"]
 
+pytorch-dataset = ["torch"]
+pytorch = ["kedro-datasets[pytorch-dataset]"]
+
 rioxarray-geotiffdataset = ["rioxarray>=0.15.0"]
 rioxarray = ["kedro-datasets[rioxarray-geotiffdataset]"]
 
@@ -279,6 +282,7 @@ experimental = [
     "netcdf4>=1.6.4",
     "xarray>=2023.1.0",
     "rioxarray",
+    "torch"
 ]
 
 # All requirements
diff --git
a/kedro-datasets/tests/kedro_datasets_experimental/__init__.py b/kedro-datasets/tests/kedro_datasets_experimental/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/kedro-datasets/tests/kedro_datasets_experimental/pytorch/__init__.py b/kedro-datasets/tests/kedro_datasets_experimental/pytorch/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/kedro-datasets/tests/kedro_datasets_experimental/pytorch/test_pytorch_dataset.py b/kedro-datasets/tests/kedro_datasets_experimental/pytorch/test_pytorch_dataset.py
new file mode 100644
index 000000000..cf04b0a2d
--- /dev/null
+++ b/kedro-datasets/tests/kedro_datasets_experimental/pytorch/test_pytorch_dataset.py
@@ -0,0 +1,94 @@
+import pytest
+import torch
+import torch.nn.functional as F
+from kedro.io.core import Version
+from torch import nn
+
+from kedro_datasets_experimental.pytorch import PyTorchDataset
+
+
+# Define model
+class TheModelClass(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+@pytest.fixture(params=[None])
+def load_args(request):
+    return request.param
+
+
+@pytest.fixture(params=[None])
+def save_args(request):
+    return request.param
+
+
+@pytest.fixture(params=[None])
+def fs_args(request):
+    return request.param
+
+
+@pytest.fixture(params=[None])
+def load_version(request):
+    return request.param
+
+
+@pytest.fixture(params=[None])
+def save_version(request):
+    return request.param
+
+
+@pytest.fixture
+def filepath_model(tmp_path):
+    return (tmp_path / "model.pt").as_posix()
+
+
+@pytest.fixture
+def pytorch_dataset(filepath_model, load_args, save_args, fs_args):
+    return PyTorchDataset(
+        filepath=filepath_model,
+        load_args=load_args,
+        save_args=save_args,
+        fs_args=fs_args,
+    )
+
+
+@pytest.fixture
+def versioned_pytorch_dataset(filepath_model, load_version, save_version):
+    # ``PyTorchDataset`` exposes a single ``version`` argument rather than
+    # separate load/save version keywords.
+    return PyTorchDataset(
+        filepath=filepath_model, version=Version(load_version, save_version)
+    )
+
+
+@pytest.fixture
+def dummy_model():
+    return TheModelClass()
+
+
+class TestPyTorchDataset:
+    def test_save_and_load_dataset(self, pytorch_dataset, dummy_model):
+        """Round-trip: a saved state_dict reloads into an equal model."""
+        pytorch_dataset.save(dummy_model)
+        model = TheModelClass()
+        model.load_state_dict(pytorch_dataset.load())
+        reloaded_state_dict = model.state_dict()
+        dummy_state_dict = dummy_model.state_dict()
+        for key, value in reloaded_state_dict.items():
+            assert torch.equal(dummy_state_dict[key], value)