-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(datasets): Added PyTorchDataset (#735)
* Added PyTorchDataset Signed-off-by: bpmeek <bpmeek.developer@gmail.com> * updated RELEASE.md Signed-off-by: bpmeek <bpmeek.developer@gmail.com> * Add dependencies for PyTorchDataset Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com> * Add PyTorchDataset to API docs Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com> * Fix docs build Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com> --------- Signed-off-by: bpmeek <bpmeek.developer@gmail.com> Signed-off-by: Merel Theisen <49397448+merelcht@users.noreply.github.com> Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com> Co-authored-by: Merel Theisen <49397448+merelcht@users.noreply.github.com> Co-authored-by: Merel Theisen <merel.theisen@quantumblack.com>
- Loading branch information
1 parent
994f86c
commit 36524d6
Showing
9 changed files
with
214 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
kedro-datasets/kedro_datasets_experimental/pytorch/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""``AbstractDataset`` implementation to load/save torch models using PyTorch's built-in methods """ | ||
|
||
from typing import Any | ||
|
||
import lazy_loader as lazy | ||
|
||
# Bare annotation (no value is bound here) so that static type checkers and
# IDEs can see the name on this package.
PyTorchDataset: Any

# ``lazy.attach`` supplies this package's module-level ``__getattr__``,
# ``__dir__`` and ``__all__``, mapping ``PyTorchDataset`` to the
# ``pytorch_dataset`` submodule.
# NOTE(review): presumably the submodule (and therefore ``torch``) is only
# imported on first attribute access -- confirm against lazy_loader docs.
__getattr__, __dir__, __all__ = lazy.attach(
    __name__, submod_attrs={"pytorch_dataset": ["PyTorchDataset"]}
)
122 changes: 122 additions & 0 deletions
122
kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from __future__ import annotations | ||
|
||
from copy import deepcopy | ||
from pathlib import PurePosixPath | ||
from typing import Any | ||
|
||
import fsspec | ||
import torch | ||
from kedro.io.core import ( | ||
AbstractVersionedDataset, | ||
DatasetError, | ||
Version, | ||
get_filepath_str, | ||
get_protocol_and_path, | ||
) | ||
|
||
|
||
class PyTorchDataset(AbstractVersionedDataset[Any, Any]):
    """``PyTorchDataset`` loads and saves the ``state_dict`` of PyTorch models
    using PyTorch's recommended zipfile serialization protocol, which avoids
    pickling the whole model object.

    ``save`` expects a ``torch.nn.Module`` and persists ``module.state_dict()``.
    ``load`` returns the object read back by ``torch.load``; pass it to
    ``load_state_dict`` on a freshly constructed model of the same
    architecture.

    .. code-block:: yaml

        model:
          type: pytorch.PyTorchDataset
          filepath: data/06_models/model.pt

    .. code-block:: pycon

        >>> from kedro_datasets_experimental.pytorch import PyTorchDataset
        >>> import torch
        >>>
        >>> model: torch.nn.Module
        >>> model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
        >>> dataset = PyTorchDataset(filepath=tmp_path / "model.pt")
        >>> dataset.save(model)
        >>> reloaded = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
        >>> reloaded.load_state_dict(dataset.load())
        <All keys matched successfully>
    """

    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: dict[str, Any] = {}

    def __init__(  # noqa: PLR0913
        self,
        *,
        filepath,
        load_args: dict[str, Any] | None = None,
        save_args: dict[str, Any] | None = None,
        version: Version | None = None,
        credentials: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        """Create a ``PyTorchDataset`` pointing at a concrete file.

        Args:
            filepath: Location of the serialized ``state_dict``. It is split
                into a protocol and a path by ``get_protocol_and_path``.
            load_args: Extra keyword arguments forwarded to ``torch.load``
                (for example ``map_location`` or ``weights_only``).
            save_args: Extra keyword arguments forwarded to ``torch.save``.
            version: Optional ``Version`` handed to
                ``AbstractVersionedDataset``.
            credentials: Keyword arguments forwarded to
                ``fsspec.filesystem`` alongside ``fs_args``.
            fs_args: Keyword arguments for ``fsspec.filesystem``. The nested
                keys ``open_args_load`` and ``open_args_save`` are removed
                from it and passed to the filesystem's ``open`` call when
                loading and saving respectively.
            metadata: Arbitrary metadata stored on the instance; this class
                does not read it.
        """
        _fs_args = deepcopy(fs_args) or {}
        _fs_open_args_load = _fs_args.pop("open_args_load", {})
        _fs_open_args_save = _fs_args.pop("open_args_save", {})
        _credentials = deepcopy(credentials) or {}

        protocol, path = get_protocol_and_path(filepath, version)
        if protocol == "file":
            # Let local saves create missing parent directories.
            _fs_args.setdefault("auto_mkdir", True)

        self._protocol = protocol
        self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

        self.metadata = metadata

        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

        # Handle default load and save arguments
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)
        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

        # torch.save writes bytes, so the target must be opened in binary
        # write mode unless the caller explicitly chose another mode.
        _fs_open_args_save.setdefault("mode", "wb")
        self._fs_open_args_load = _fs_open_args_load
        self._fs_open_args_save = _fs_open_args_save

    def _describe(self) -> dict[str, Any]:
        """Return the attributes that describe this dataset instance."""
        return {
            "filepath": self._filepath,
            "protocol": self._protocol,
            "load_args": self._load_args,
            "save_args": self._save_args,
            "version": self._version,
        }

    def _load(self) -> Any:
        """Read the stored object through the configured filesystem.

        The file is opened with ``open_args_load`` and ``load_args`` are
        forwarded to ``torch.load``.
        """
        load_path = get_filepath_str(self._get_load_path(), self._protocol)
        # SECURITY: torch.load unpickles arbitrary objects unless
        # ``weights_only=True`` is in effect. Set it via ``load_args`` when
        # the file may come from an untrusted source.
        with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
            return torch.load(fs_file, **self._load_args)

    def _save(self, data: torch.nn.Module) -> None:
        """Serialize ``data.state_dict()`` through the configured filesystem.

        The file is opened with ``open_args_save`` and ``save_args`` are
        forwarded to ``torch.save``.
        """
        save_path = get_filepath_str(self._get_save_path(), self._protocol)
        with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
            torch.save(data.state_dict(), fs_file, **self._save_args)

        self._invalidate_cache()

    def _exists(self) -> bool:
        """Return whether the load path currently exists on the filesystem."""
        try:
            load_path = get_filepath_str(self._get_load_path(), self._protocol)
        except DatasetError:
            # Resolving the load path failed, so report the dataset as absent.
            return False

        return self._fs.exists(load_path)

    def _release(self) -> None:
        """Release the dataset and drop the filesystem cache for its path."""
        super()._release()
        self._invalidate_cache()

    def _invalidate_cache(self) -> None:
        """Invalidate underlying filesystem caches."""
        filepath = get_filepath_str(self._filepath, self._protocol)
        self._fs.invalidate_cache(filepath)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
65 changes: 65 additions & 0 deletions
65
kedro-datasets/tests/kedro_datasets_experimental/pytorch/test_pytorch_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import pytest | ||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn | ||
|
||
from kedro_datasets_experimental.pytorch import PyTorchDataset | ||
|
||
|
||
# Define model | ||
class TheModelClass(nn.Module):
    """Small convolutional network used as a save/load fixture.

    Two convolution + max-pool stages feed three fully connected layers that
    end in 10 outputs. ``forward`` flattens to ``16 * 5 * 5`` features per
    row before the first linear layer.
    """

    def __init__(self):
        super().__init__()
        flat_features = 16 * 5 * 5
        # Attribute names and creation order are kept stable: they determine
        # the ``state_dict`` keys and the order of parameter initialisation.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(flat_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the conv/pool stages, flatten, then apply the linear layers."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flattened = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flattened))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
|
||
|
||
@pytest.fixture
def filepath_model(tmp_path):
    """POSIX-style string path to ``model.pt`` inside pytest's ``tmp_path``."""
    model_file = tmp_path / "model.pt"
    return model_file.as_posix()
|
||
|
||
@pytest.fixture
def pytorch_dataset(filepath_model, load_args, save_args, fs_args):
    """Build an unversioned ``PyTorchDataset`` from the requested fixtures.

    NOTE(review): ``load_args``, ``save_args`` and ``fs_args`` are not defined
    in this file -- presumably they come from a shared ``conftest.py``.
    """
    dataset_kwargs = {
        "filepath": filepath_model,
        "load_args": load_args,
        "save_args": save_args,
        "fs_args": fs_args,
    }
    return PyTorchDataset(**dataset_kwargs)
|
||
|
||
@pytest.fixture
def versioned_pytorch_dataset(filepath_model, load_version, save_version):
    """Build a versioned ``PyTorchDataset``.

    ``PyTorchDataset.__init__`` is keyword-only and accepts a single
    ``version`` argument, so the load and save versions are wrapped in a
    ``Version`` instead of being passed as separate keywords (which raises
    ``TypeError``).

    NOTE(review): ``load_version`` and ``save_version`` are not defined in
    this file -- presumably they come from a shared ``conftest.py``.
    """
    # Imported locally so this fixture does not depend on the module's
    # top-level import block.
    from kedro.io.core import Version

    return PyTorchDataset(
        filepath=filepath_model, version=Version(load_version, save_version)
    )
|
||
|
||
@pytest.fixture
def dummy_model():
    """Provide a freshly constructed ``TheModelClass`` instance."""
    model = TheModelClass()
    return model
|
||
|
||
class TestPyTorchDataset:
    """Round-trip tests for ``PyTorchDataset``."""

    def test_save_and_load_dataset(self, pytorch_dataset, dummy_model, filepath_model):
        """Saving a model and loading its state dict back yields equal tensors."""
        pytorch_dataset.save(dummy_model)

        # Load the persisted state dict into a second, independently
        # initialised model of the same architecture.
        restored = TheModelClass()
        restored.load_state_dict(pytorch_dataset.load())

        expected_state = dummy_model.state_dict()
        for name, tensor in restored.state_dict().items():
            assert torch.equal(expected_state[name], tensor)