Skip to content

Commit

Permalink
Merge pull request kedro-org#829 from quantumblacklabs/merge-master-t…
Browse files Browse the repository at this point in the history
…o-develop

Merge master into develop via merge-master-to-develop
  • Loading branch information
idanov authored Oct 20, 2020
2 parents 0c74a14 + 4c64060 commit 7e8fbee
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
3 changes: 2 additions & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
* Fixed `kedro install` for an Anaconda environment defined in `environment.yml`.
* Fixed backwards compatibility with templates generated with older Kedro versions <0.16.5. No longer need to update `.kedro.yml` to use `kedro lint` and `kedro jupyter notebook convert`.
* Improved documentation.
* Fixed issue with saving a `TensorFlowModelDataset` in the HDF5 format with versioning enabled.

## Breaking changes to the API

## Thanks for supporting contributions
[Deepyaman Datta](https://github.com/deepyaman), [Bhavya Merchant](https://github.com/bnmerchant), [Lovkush Agarwal](https://github.com/Lovkush-A), [Varun Krishna S](https://github.com/vhawk19), [Sebastian Bertoli](https://github.com/sebastianbertoli), [Saran Balaji C](https://github.com/csaranbalaji)
[Deepyaman Datta](https://github.com/deepyaman), [Bhavya Merchant](https://github.com/bnmerchant), [Lovkush Agarwal](https://github.com/Lovkush-A), [Varun Krishna S](https://github.com/vhawk19), [Sebastian Bertoli](https://github.com/sebastianbertoli), [Daniel Petti](https://github.com/djpetti), [Saran Balaji C](https://github.com/csaranbalaji)

# Release 0.16.5

Expand Down
6 changes: 5 additions & 1 deletion kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""
import copy
import tempfile
from pathlib import PurePath, PurePosixPath
from pathlib import Path, PurePath, PurePosixPath
from typing import Any, Dict

import fsspec
Expand Down Expand Up @@ -151,6 +151,10 @@ def _load(self) -> tf.keras.Model:
def _save(self, data: tf.keras.Model) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

# Make sure all intermediate directories are created.
save_dir = Path(save_path).parent
save_dir.mkdir(parents=True, exist_ok=True)

with tempfile.TemporaryDirectory(prefix=self._tmp_prefix) as path:
if self._is_h5:
path = str(PurePath(path) / TEMPORARY_H5_FILE)
Expand Down
24 changes: 24 additions & 0 deletions tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,30 @@ def test_save_and_load(
new_predictions = reloaded.predict(dummy_x_test)
np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)

def test_hdf5_save_format(
self,
dummy_tf_base_model,
dummy_x_test,
filepath,
tensorflow_model_dataset,
load_version,
save_version,
):
"""Test versioned TensorflowModelDataset can save TF graph models in
HDF5 format"""
hdf5_dataset = tensorflow_model_dataset(
filepath=filepath,
save_args={"save_format": "h5"},
version=Version(load_version, save_version),
)

predictions = dummy_tf_base_model.predict(dummy_x_test)
hdf5_dataset.save(dummy_tf_base_model)

reloaded = hdf5_dataset.load()
new_predictions = reloaded.predict(dummy_x_test)
np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)

def test_prevent_overwrite(self, dummy_tf_base_model, versioned_tf_model_dataset):
"""Check the error when attempting to override the data set if the
corresponding file for a given save version already exists."""
Expand Down

0 comments on commit 7e8fbee

Please sign in to comment.