forked from kedro-org/kedro-plugins
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Implement `plotly.HTMLDataset` (kedro-org#788)
* Implement `plotly.HTMLDataset` Signed-off-by: Yury Fedotov <yury_fedotov@mckinsey.com> Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
- Loading branch information
1 parent
abeafa4
commit 20f7ffc
Showing
6 changed files
with
255 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
"""``HTMLDataset`` saves a plotly figure to an HTML file using an underlying | ||
filesystem (e.g.: local, S3, GCS). | ||
""" | ||
from __future__ import annotations | ||
|
||
from copy import deepcopy | ||
from pathlib import PurePosixPath | ||
from typing import Any, NoReturn, Union | ||
|
||
import fsspec | ||
from kedro.io.core import ( | ||
AbstractVersionedDataset, | ||
DatasetError, | ||
Version, | ||
get_filepath_str, | ||
get_protocol_and_path, | ||
) | ||
from plotly import graph_objects as go | ||
|
||
|
||
class HTMLDataset(
    AbstractVersionedDataset[go.Figure, Union[go.Figure, go.FigureWidget]]
):
    """Write-only dataset that renders a plotly figure to an HTML file.

    The file is written through ``fsspec``, so the target may live on any
    filesystem ``fsspec`` supports (e.g.: local, S3, GCS). Loading is not
    supported: calling ``load`` raises ``DatasetError``.

    Example usage for the
    `YAML API <https://kedro.readthedocs.io/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        scatter_plot:
          type: plotly.HTMLDataset
          filepath: data/08_reporting/scatter_plot.html
          save_args:
            auto_open: False

    Example usage for the
    `Python API <https://kedro.readthedocs.io/en/stable/data/\
    advanced_data_catalog_usage.html>`_:

    .. code-block:: pycon

        >>> from kedro_datasets.plotly import HTMLDataset
        >>> import plotly.express as px
        >>>
        >>> fig = px.bar(x=["a", "b", "c"], y=[1, 3, 2])
        >>> dataset = HTMLDataset(filepath=tmp_path / "test.html")
        >>> dataset.save(fig)

    """

    # Keyword arguments forwarded to ``Figure.write_html`` unless overridden.
    DEFAULT_SAVE_ARGS: dict[str, Any] = {}
    # Keyword arguments forwarded to the filesystem's ``open`` when saving,
    # unless overridden through ``fs_args["open_args_save"]``.
    DEFAULT_FS_ARGS: dict[str, Any] = {
        "open_args_save": {"mode": "w", "encoding": "utf-8"}
    }

    def __init__(  # noqa: PLR0913
        self,
        *,
        filepath: str,
        save_args: dict[str, Any] | None = None,
        version: Version | None = None,
        credentials: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Creates a new instance of ``HTMLDataset`` pointing to a concrete HTML
        file on a specific filesystem.

        Args:
            filepath: Filepath in POSIX format to an HTML file, optionally prefixed
                with a protocol like `s3://`. If prefix is not provided, `file`
                protocol (local filesystem) will be used. The prefix should be any
                protocol supported by ``fsspec``.
                Note: `http(s)` doesn't support versioning.
            save_args: Plotly options for saving HTML files, merged on top of
                ``DEFAULT_SAVE_ARGS`` and passed to ``write_html``.
                Here you can find all available arguments:
                https://plotly.com/python-api-reference/generated/plotly.io.write_html.html#plotly.io.write_html
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            credentials: Credentials required to get access to the underlying
                filesystem. E.g. for ``GCSFileSystem`` it should look like
                `{'token': None}`.
            fs_args: Extra arguments to pass into the underlying filesystem class
                constructor (e.g. `{"project": "my-project"}` for
                ``GCSFileSystem``). The nested key `open_args_save` is removed
                first and passed to the filesystem's `open` method when saving;
                no other nested key is treated specially.
                Here you can find all available arguments for `open`:
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
                When saving, `mode` defaults to `w` and `encoding` to `utf-8`.
                For the `file` protocol, `auto_mkdir` defaults to True.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or
                external plugins.
        """
        # Work on copies so the caller's dictionaries are never mutated.
        fs_options = deepcopy(fs_args) if fs_args else {}
        fs_credentials = deepcopy(credentials) if credentials else {}
        # ``open_args_save`` is meant for ``open``, not for the filesystem
        # constructor, so it is taken out before the filesystem is built.
        save_open_overrides = fs_options.pop("open_args_save", {})

        self._protocol, fs_path = get_protocol_and_path(filepath, version)
        if self._protocol == "file":
            fs_options.setdefault("auto_mkdir", True)
        self._fs = fsspec.filesystem(self._protocol, **fs_credentials, **fs_options)

        self.metadata = metadata

        super().__init__(
            filepath=PurePosixPath(fs_path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

        # Layer user-supplied values over the class-level defaults.
        merged_save_args = dict(self.DEFAULT_SAVE_ARGS)
        merged_save_args.update(save_args or {})
        self._save_args = merged_save_args

        merged_open_args = dict(self.DEFAULT_FS_ARGS.get("open_args_save", {}))
        merged_open_args.update(save_open_overrides or {})
        self._fs_open_args_save = merged_open_args

    def _describe(self) -> dict[str, Any]:
        """Return the filepath, protocol, save args and version of this dataset."""
        return dict(
            filepath=self._filepath,
            protocol=self._protocol,
            save_args=self._save_args,
            version=self._version,
        )

    def _load(self) -> NoReturn:
        """Always raise ``DatasetError``: this dataset cannot be loaded."""
        message = f"Loading not supported for '{self.__class__.__name__}'"
        raise DatasetError(message)

    def _save(self, data: go.Figure) -> None:
        """Write ``data`` as HTML to the save path, then invalidate the cache."""
        target = get_filepath_str(self._get_save_path(), self._protocol)

        with self._fs.open(target, **self._fs_open_args_save) as html_file:
            data.write_html(html_file, **self._save_args)

        self._invalidate_cache()

    def _exists(self) -> bool:
        """Report whether the load path exists on the underlying filesystem."""
        candidate = get_filepath_str(self._get_load_path(), self._protocol)
        found = self._fs.exists(candidate)
        return found

    def _release(self) -> None:
        """Run the parent release logic, then invalidate the filesystem cache."""
        super()._release()
        self._invalidate_cache()

    def _invalidate_cache(self) -> None:
        """Invalidate the underlying filesystem cache for this dataset's filepath."""
        cached_path = get_filepath_str(self._filepath, self._protocol)
        self._fs.invalidate_cache(cached_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from pathlib import PurePosixPath | ||
|
||
import plotly.express as px | ||
import pytest | ||
from adlfs import AzureBlobFileSystem | ||
from fsspec.implementations.http import HTTPFileSystem | ||
from fsspec.implementations.local import LocalFileSystem | ||
from gcsfs import GCSFileSystem | ||
from kedro.io.core import PROTOCOL_DELIMITER, DatasetError | ||
from s3fs.core import S3FileSystem | ||
|
||
from kedro_datasets.plotly import HTMLDataset | ||
|
||
|
||
@pytest.fixture
def filepath_html(tmp_path):
    """Return the POSIX-style path of ``test.html`` inside ``tmp_path``."""
    html_file = tmp_path.joinpath("test.html")
    return html_file.as_posix()
|
||
|
||
@pytest.fixture
def html_dataset(filepath_html, save_args, fs_args):
    """Build an ``HTMLDataset`` from the path, save-args and fs-args fixtures.

    NOTE(review): ``save_args`` and ``fs_args`` are not defined in this module;
    presumably they come from a shared ``conftest`` - confirm.
    """
    dataset_kwargs = {
        "filepath": filepath_html,
        "save_args": save_args,
        "fs_args": fs_args,
    }
    return HTMLDataset(**dataset_kwargs)
|
||
|
||
@pytest.fixture
def dummy_plot():
    """Return a small titled scatter figure used as input for saving."""
    x_values = [1, 2, 3]
    y_values = [1, 3, 2]
    return px.scatter(x=x_values, y=y_values, title="Test")
|
||
|
||
class TestHTMLDataset:
    """Tests for the write-only ``HTMLDataset``."""

    def test_save(self, html_dataset, dummy_plot):
        """Test that saving writes the file using the default ``open`` arguments."""
        html_dataset.save(dummy_plot)
        # The dataset cannot be loaded back, so check the file was produced.
        assert html_dataset.exists()
        assert html_dataset._fs_open_args_save == {"mode": "w", "encoding": "utf-8"}

    def test_exists(self, html_dataset, dummy_plot):
        """Test `exists` method invocation for both existing and
        nonexistent data set."""
        assert not html_dataset.exists()
        html_dataset.save(dummy_plot)
        assert html_dataset.exists()

    def test_load_is_impossible(self, html_dataset):
        """Check the error when trying to load a dataset."""
        pattern = "Loading not supported"
        with pytest.raises(DatasetError, match=pattern):
            html_dataset.load()

    @pytest.mark.parametrize("save_args", [{"auto_play": False}])
    def test_save_extra_params(self, html_dataset, save_args):
        """Test overriding default save args"""
        for k, v in save_args.items():
            assert html_dataset._save_args[k] == v

    @pytest.mark.parametrize(
        "filepath,instance_type,credentials",
        [
            ("s3://bucket/file.html", S3FileSystem, {}),
            ("file:///tmp/test.html", LocalFileSystem, {}),
            ("/tmp/test.html", LocalFileSystem, {}),
            ("gcs://bucket/file.html", GCSFileSystem, {}),
            ("https://example.com/file.html", HTTPFileSystem, {}),
            (
                "abfs://bucket/file.html",
                AzureBlobFileSystem,
                {"account_name": "test", "account_key": "test"},
            ),
        ],
    )
    def test_protocol_usage(self, filepath, instance_type, credentials):
        """Test that the filepath protocol selects the matching filesystem class
        and that the stored filepath has the protocol prefix stripped."""
        dataset = HTMLDataset(filepath=filepath, credentials=credentials)
        assert isinstance(dataset._fs, instance_type)

        # Everything after the protocol delimiter (or the whole string when no
        # protocol is given) is what the dataset is expected to keep.
        path = filepath.split(PROTOCOL_DELIMITER, 1)[-1]

        assert str(dataset._filepath) == path
        assert isinstance(dataset._filepath, PurePosixPath)

    def test_catalog_release(self, mocker):
        """Test that releasing the dataset invalidates the filesystem cache
        exactly once, for the dataset's filepath."""
        fs_mock = mocker.patch("fsspec.filesystem").return_value
        filepath = "test.html"
        dataset = HTMLDataset(filepath=filepath)
        dataset.release()
        fs_mock.invalidate_cache.assert_called_once_with(filepath)