Skip to content

Commit

Permalink
feat(datasets): Added PyTorchDataset (kedro-org#735)
Browse files Browse the repository at this point in the history
* Added PyTorchDataset

Signed-off-by: bpmeek <bpmeek.developer@gmail.com>

* updated RELEASE.md

Signed-off-by: bpmeek <bpmeek.developer@gmail.com>

* Add dependencies for PyTorchDataset

Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>

* Add PyTorchDataset to API docs

Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>

* Fix docs build

Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>

---------

Signed-off-by: bpmeek <bpmeek.developer@gmail.com>
Signed-off-by: Merel Theisen <49397448+merelcht@users.noreply.github.com>
Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
Co-authored-by: Merel Theisen <49397448+merelcht@users.noreply.github.com>
Co-authored-by: Merel Theisen <merel.theisen@quantumblack.com>
Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
  • Loading branch information
3 people committed Aug 27, 2024
1 parent 125dc37 commit 9b07173
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 2 deletions.
12 changes: 10 additions & 2 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
# Upcoming Release
## Major features and improvements
* Added the following new **experimental** datasets:

| Type | Description | Location |
|-------------------------------------|-----------------------------------------------------------|-----------------------------------------|
| `pytorch.PyTorchDataset` | A dataset for securely saving and loading PyTorch models | `kedro_datasets_experimental.pytorch` |

## Bug fixes and other changes
## Breaking Changes
## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [Brandon Meek](https://github.com/bpmeek)


# Release 4.1.0
## Major features and improvements
Expand All @@ -27,13 +35,13 @@
| `langchain.OpenAIEmbeddingsDataset` | A dataset for loading a OpenAIEmbeddings langchain model. | `kedro_datasets_experimental.langchain` |
| `langchain.ChatOpenAIDataset` | A dataset for loading a ChatOpenAI langchain model. | `kedro_datasets_experimental.langchain` |
| `rioxarray.GeoTIFFDataset` | A dataset for loading and saving geotiff raster data | `kedro_datasets_experimental.rioxarray` |
| `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` |
| `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` |

* Added the following new core datasets:

| Type | Description | Location |
|-------------------------------------|-----------------------------------------------------------|-----------------------------------------|
| `dask.CSVDataset` | A dataset for loading a CSV files using `dask` | `kedro_datasets.dask` |
| `dask.CSVDataset` | A dataset for loading a CSV files using `dask` | `kedro_datasets.dask` |

* Extended preview feature to `yaml.YAMLDataset`.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ kedro_datasets_experimental
langchain.ChatOpenAIDataset
langchain.OpenAIEmbeddingsDataset
netcdf.NetCDFDataset
pytorch.PyTorchDataset
rioxarray.GeoTIFFDataset
1 change: 1 addition & 0 deletions kedro-datasets/docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
"langchain_cohere.chat_models.ChatCohere",
"xarray.core.dataset.Dataset",
"xarray.core.dataarray.DataArray",
"torch.nn.modules.module.Module",
),
"py:data": (
"typing.Any",
Expand Down
11 changes: 11 additions & 0 deletions kedro-datasets/kedro_datasets_experimental/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""``AbstractDataset`` implementation to load/save torch models using PyTorch's built-in methods """

from typing import Any

import lazy_loader as lazy

PyTorchDataset: Any

__getattr__, __dir__, __all__ = lazy.attach(
__name__, submod_attrs={"pytorch_dataset": ["PyTorchDataset"]}
)
122 changes: 122 additions & 0 deletions kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from copy import deepcopy
from pathlib import PurePosixPath
from typing import Any

import fsspec
import torch
from kedro.io.core import (
AbstractVersionedDataset,
DatasetError,
Version,
get_filepath_str,
get_protocol_and_path,
)


class PyTorchDataset(AbstractVersionedDataset[Any, Any]):
"""``PyTorchDataset`` loads and saves PyTorch models' `state_dict`
using PyTorch's recommended zipfile serialization protocol. To avoid
security issues with Pickle.
.. code-block:: yaml
model:
type: pytorch.PyTorchDataset
filepath: data/06_models/model.pt
.. code-block:: pycon
>>> from kedro_datasets_experimental.pytorch import PyTorchDataset
>>> import torch
>>>
>>> model: torch.nn.Module
>>> model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
>>> dataset = PyTorchDataset(filepath=tmp_path / "model.pt")
>>> dataset.save(model)
>>> reloaded = TheModelClass(*args, **kwargs)
>>> reloaded.load_state_dict(dataset.load())
"""

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}

def __init__( # noqa: PLR0913
self,
*,
filepath,
load_args: dict[str, Any] = None,
save_args: dict[str, Any] = None,
version: Version | None = None,
credentials: dict[str, Any] = None,
fs_args: dict[str, Any] = None,
metadata: dict[str, Any] = None,
):
_fs_args = deepcopy(fs_args) or {}
_fs_open_args_load = _fs_args.pop("open_args_load", {})
_fs_open_args_save = _fs_args.pop("open_args_save", {})
_credentials = deepcopy(credentials) or {}

protocol, path = get_protocol_and_path(filepath, version)
if protocol == "file":
_fs_args.setdefault("auto_mkdir", True)

self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

self.metadata = metadata

super().__init__(
filepath=PurePosixPath(path),
version=version,
exists_function=self._fs.exists,
glob_function=self._fs.glob,
)

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save

def _describe(self) -> dict[str, Any]:
return {
"filepath": self._filepath,
"protocol": self._protocol,
"load_args": self._load_args,
"save_args": self._save_args,
"version": self._version,
}

def _load(self) -> Any:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
return torch.load(load_path, **self._fs_open_args_load)

def _save(self, data: torch.nn.Module) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
torch.save(data.state_dict(), save_path, **self._fs_open_args_save)

self._invalidate_cache()

def _exists(self):
try:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
except DatasetError:
return False

return self._fs.exists(load_path)

def _release(self) -> None:
super()._release()
self._invalidate_cache()

def _invalidate_cache(self) -> None:
"""Invalidate underlying filesystem caches."""
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)
4 changes: 4 additions & 0 deletions kedro-datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ langchain = ["kedro-datasets[langchain-chatopenaidataset,langchain-openaiembeddi
netcdf-netcdfdataset = ["h5netcdf>=1.2.0","netcdf4>=1.6.4","xarray>=2023.1.0"]
netcdf = ["kedro-datasets[netcdf-netcdfdataset]"]

pytorch-dataset = ["torch"]
pytorch = ["kedro-datasets[pytorch-dataset]"]

rioxarray-geotiffdataset = ["rioxarray>=0.15.0"]
rioxarray = ["kedro-datasets[rioxarray-geotiffdataset]"]

Expand Down Expand Up @@ -279,6 +282,7 @@ experimental = [
"netcdf4>=1.6.4",
"xarray>=2023.1.0",
"rioxarray",
"torch"
]

# All requirements
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest
import torch
import torch.nn.functional as F
from torch import nn

from kedro_datasets_experimental.pytorch import PyTorchDataset


# Define model
class TheModelClass(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


@pytest.fixture
def filepath_model(tmp_path):
return (tmp_path / "model.pt").as_posix()


@pytest.fixture
def pytorch_dataset(filepath_model, load_args, save_args, fs_args):
return PyTorchDataset(
filepath=filepath_model,
load_args=load_args,
save_args=save_args,
fs_args=fs_args,
)


@pytest.fixture
def versioned_pytorch_dataset(filepath_model, load_version, save_version):
return PyTorchDataset(
filepath=filepath_model, load_version=load_version, save_version=save_version
)


@pytest.fixture
def dummy_model():
return TheModelClass()


class TestPyTorchDataset:
def test_save_and_load_dataset(self, pytorch_dataset, dummy_model, filepath_model):
pytorch_dataset.save(dummy_model)
model = TheModelClass()
model.load_state_dict(pytorch_dataset.load())
reloaded_state_dict = model.state_dict()
dummy_state_dict = dummy_model.state_dict()
for key, value in reloaded_state_dict.items():
assert torch.equal(dummy_state_dict[key], value)

0 comments on commit 9b07173

Please sign in to comment.