Skip to content

Commit

Permalink
Merge branch 'main' into 371-telemetry-fix-masking-of-cli-commands
Browse files Browse the repository at this point in the history
  • Loading branch information
DimedS committed Feb 9, 2024
2 parents 516ff75 + 004888c commit ab42422
Show file tree
Hide file tree
Showing 18 changed files with 175 additions and 13 deletions.
1 change: 1 addition & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Upcoming Release
## Major features and improvements
* Added `MatlabDataset` which uses `scipy` to save and load `.mat` files.
* Extended the preview feature to the matplotlib, plotly and tracking datasets.

## Bug fixes and other changes
* Removed Windows specific conditions in `pandas.HDFDataset` extra dependencies
Expand Down
1 change: 1 addition & 0 deletions kedro-datasets/docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@

# External documentation sets that cross-references may resolve against.
# Each key names a project and maps to a (base URL, inventory) pair.
# NOTE(review): ``None`` presumably selects the default inventory location
# under the base URL — confirm against the Sphinx intersphinx documentation.
intersphinx_mapping = {
"kedro": ("https://docs.kedro.org/en/stable/", None),
"python": ("https://docs.python.org/3.9/", None)
}

type_targets = {
Expand Down
16 changes: 16 additions & 0 deletions kedro-datasets/kedro_datasets/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
`_typing.py` defines custom data types for Kedro-viz integration. It uses NewType from the typing module.
These types are used to facilitate data rendering in the Kedro-viz front-end.
"""

from typing import NewType

# Preview payload types. ``NewType`` is an identity callable at runtime, so
# these names only tag a value for static type checkers and for code that
# inspects the return annotation of a dataset's ``preview`` method.
TablePreview = NewType("TablePreview", dict)
# Image previews are base64 data decoded to text (see
# ``MatplotlibWriter.preview``), hence ``str`` rather than ``bytes``.
ImagePreview = NewType("ImagePreview", str)
PlotlyPreview = NewType("PlotlyPreview", dict)
JSONPreview = NewType("JSONPreview", dict)


# experiment tracking datasets types
MetricsTrackingPreview = NewType("MetricsTrackingPreview", dict)
JSONTrackingPreview = NewType("JSONTrackingPreview", dict)
15 changes: 15 additions & 0 deletions kedro-datasets/kedro_datasets/matplotlib/matplotlib_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""``MatplotlibWriter`` saves one or more Matplotlib objects as image
files to an underlying filesystem (e.g. local, S3, GCS)."""

import base64
import io
from copy import deepcopy
from pathlib import PurePosixPath
Expand All @@ -17,6 +18,8 @@
get_protocol_and_path,
)

from kedro_datasets._typing import ImagePreview


class MatplotlibWriter(
AbstractVersionedDataset[
Expand Down Expand Up @@ -245,3 +248,15 @@ def _invalidate_cache(self) -> None:
"""Invalidate underlying filesystem caches."""
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)

def preview(self) -> ImagePreview:
    """Return the stored matplotlib image as base64 text.

    Returns:
        str: The bytes of the image file at the load path, base64 encoded
        and then decoded as UTF-8.
    """
    image_location = get_filepath_str(self._get_load_path(), self._protocol)
    # Binary mode: the raw image bytes are what gets base64 encoded.
    with self._fs.open(image_location, mode="rb") as image_handle:
        raw_image = image_handle.read()
    return base64.b64encode(raw_image).decode("utf-8")
13 changes: 12 additions & 1 deletion kedro-datasets/kedro_datasets/pandas/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
get_protocol_and_path,
)

from kedro_datasets._typing import TablePreview

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -194,7 +196,16 @@ def _invalidate_cache(self) -> None:
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)

def _preview(self, nrows: int = 40) -> dict:
def preview(self, nrows: int = 5) -> TablePreview:
"""
Generate a preview of the dataset with a specified number of rows.
Args:
nrows: The number of rows to include in the preview. Defaults to 5.
Returns:
dict: A dictionary containing the data in a split format.
"""
# Create a copy so it doesn't contaminate the original dataset
dataset_copy = self._copy()
dataset_copy._load_args["nrows"] = nrows
Expand Down
13 changes: 12 additions & 1 deletion kedro-datasets/kedro_datasets/pandas/excel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
get_protocol_and_path,
)

from kedro_datasets._typing import TablePreview

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -262,7 +264,16 @@ def _invalidate_cache(self) -> None:
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)

def _preview(self, nrows: int = 40) -> dict:
def preview(self, nrows: int = 5) -> TablePreview:
"""
Generate a preview of the dataset with a specified number of rows.
Args:
nrows: The number of rows to include in the preview. Defaults to 5.
Returns:
dict: A dictionary containing the data in a split format.
"""
# Create a copy so it doesn't contaminate the original dataset
dataset_copy = self._copy()
dataset_copy._load_args["nrows"] = nrows
Expand Down
14 changes: 14 additions & 0 deletions kedro-datasets/kedro_datasets/plotly/json_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""``JSONDataset`` loads/saves a plotly figure from/to a JSON file using an underlying
filesystem (e.g.: local, S3, GCS).
"""
import json
from copy import deepcopy
from pathlib import PurePosixPath
from typing import Any, Union
Expand All @@ -15,6 +16,8 @@
)
from plotly import graph_objects as go

from kedro_datasets._typing import PlotlyPreview


class JSONDataset(
AbstractVersionedDataset[go.Figure, Union[go.Figure, go.FigureWidget]]
Expand Down Expand Up @@ -167,3 +170,14 @@ def _release(self) -> None:
def _invalidate_cache(self) -> None:
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)

def preview(self) -> PlotlyPreview:
    """Read the stored plotly JSON file for previewing.

    Returns:
        dict: The parsed content of the JSON file at the load path.
    """
    source_path = get_filepath_str(self._get_load_path(), self._protocol)
    with self._fs.open(source_path, **self._fs_open_args_load) as json_file:
        figure_spec = json.load(json_file)
    return figure_spec
15 changes: 14 additions & 1 deletion kedro-datasets/kedro_datasets/plotly/plotly_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
file using an underlying filesystem (e.g.: local, S3, GCS). It loads the JSON into a
plotly figure.
"""
import json
from copy import deepcopy
from typing import Any

import pandas as pd
import plotly.express as px
from kedro.io.core import Version
from kedro.io.core import Version, get_filepath_str
from plotly import graph_objects as go

from kedro_datasets._typing import PlotlyPreview
from kedro_datasets.plotly.json_dataset import JSONDataset


Expand Down Expand Up @@ -148,3 +150,14 @@ def _plot_dataframe(self, data: pd.DataFrame) -> go.Figure:
fig.update_layout(template=self._plotly_args.get("theme", "plotly"))
fig.update_layout(self._plotly_args.get("layout", {}))
return fig

def preview(self) -> PlotlyPreview:
    """Return the saved plotly figure JSON as parsed data.

    Returns:
        dict: The decoded content of the JSON file at the load path.
    """
    resolved_path = get_filepath_str(self._get_load_path(), self._protocol)
    with self._fs.open(resolved_path, **self._fs_open_args_load) as stream:
        plot_payload = json.load(stream)
    return plot_payload
11 changes: 10 additions & 1 deletion kedro-datasets/kedro_datasets/tracking/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file.
The ``JSONDataset`` is part of Kedro Experiment Tracking. The dataset is versioned by default.
"""
import json
from typing import NoReturn

from kedro.io.core import DatasetError
from kedro.io.core import DatasetError, get_filepath_str

from kedro_datasets._typing import JSONTrackingPreview
from kedro_datasets.json import json_dataset


Expand Down Expand Up @@ -44,3 +46,10 @@ class JSONDataset(json_dataset.JSONDataset):

def _load(self) -> NoReturn:
raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'")

def preview(self) -> JSONTrackingPreview:
    """Load the JSON tracking dataset used in Kedro-viz experiment tracking.

    Returns:
        The parsed content of the JSON file at the load path.
    """
    tracking_path = get_filepath_str(self._get_load_path(), self._protocol)
    with self._fs.open(tracking_path, **self._fs_open_args_load) as tracking_file:
        tracked_data = json.load(tracking_file)
    return tracked_data
8 changes: 8 additions & 0 deletions kedro-datasets/kedro_datasets/tracking/metrics_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from kedro.io.core import DatasetError, get_filepath_str

from kedro_datasets._typing import MetricsTrackingPreview
from kedro_datasets.json import json_dataset


Expand Down Expand Up @@ -65,3 +66,10 @@ def _save(self, data: dict[str, float]) -> None:
json.dump(data, fs_file, **self._save_args)

self._invalidate_cache()

def preview(self) -> MetricsTrackingPreview:
    """Load the Metrics tracking dataset used in Kedro-viz experiment tracking.

    Returns:
        The parsed content of the JSON file at the load path.
    """
    metrics_path = get_filepath_str(self._get_load_path(), self._protocol)
    with self._fs.open(metrics_path, **self._fs_open_args_load) as metrics_file:
        metrics = json.load(metrics_file)
    return metrics
7 changes: 2 additions & 5 deletions kedro-datasets/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ def _collect_requirements(requires):
}
redis_require = {"redis.PickleDataset": ["redis~=4.1"]}
snowflake_require = {
"snowflake.SnowparkTableDataset": [
"snowflake-snowpark-python~=1.0",
"pyarrow~=8.0",
]
"snowflake.SnowparkTableDataset": ["snowflake-snowpark-python~=1.0"]
}
spark_require = {
"spark.SparkDataset": [SPARK, HDFS, S3FS],
Expand Down Expand Up @@ -184,7 +181,7 @@ def _collect_requirements(requires):
"coverage[toml]",
"dask[complete]>=2021.10",
"delta-spark>=1.0, <3.0",
"deltalake>=0.10.0",
"deltalake>=0.10.0, <0.15.2", # temporary pin as 0.15.2 breaks some of our tests
"dill~=0.3.1",
"filelock>=3.4.0, <4.0",
"gcsfs>=2023.1, <2023.3",
Expand Down
12 changes: 12 additions & 0 deletions kedro-datasets/tests/matplotlib/test_matplotlib_writer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
from pathlib import Path

Expand Down Expand Up @@ -253,6 +254,17 @@ def test_release(self, mocker):
dataset.release()
fs_mock.invalidate_cache.assert_called_once_with(f"{BUCKET_NAME}/{KEY_PATH}")

def test_preview(self, mock_single_plot, plot_writer):
    """Check the preview of a saved plot and its declared return type."""
    plot_writer.save(mock_single_plot)
    encoded_image = plot_writer.preview()
    # Expected leading characters of the base64 encoded image string
    assert encoded_image.startswith("iVBORw0KGgoAAAANSUh")
    declared_return = inspect.signature(plot_writer.preview).return_annotation
    assert declared_return.__name__ == "ImagePreview"


class TestMatplotlibWriterVersioned:
def test_version_str_repr(self, load_version, save_version):
Expand Down
9 changes: 7 additions & 2 deletions kedro-datasets/tests/pandas/test_csv_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import os
import sys
from pathlib import Path, PurePosixPath
Expand Down Expand Up @@ -174,10 +175,14 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path):
],
)
def test_preview(self, csv_dataset, dummy_dataframe, nrows, expected):
"""Test _preview returns the correct data structure."""
"""Test preview returns the correct data structure."""
csv_dataset.save(dummy_dataframe)
previewed = csv_dataset._preview(nrows=nrows)
previewed = csv_dataset.preview(nrows=nrows)
assert previewed == expected
assert (
inspect.signature(csv_dataset.preview).return_annotation.__name__
== "TablePreview"
)

def test_load_missing_file(self, csv_dataset):
"""Check the error when trying to load missing file."""
Expand Down
9 changes: 7 additions & 2 deletions kedro-datasets/tests/pandas/test_excel_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from pathlib import Path, PurePosixPath

import pandas as pd
Expand Down Expand Up @@ -158,10 +159,14 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path):
],
)
def test_preview(self, excel_dataset, dummy_dataframe, nrows, expected):
"""Test _preview returns the correct data structure."""
"""Test preview returns the correct data structure."""
excel_dataset.save(dummy_dataframe)
previewed = excel_dataset._preview(nrows=nrows)
previewed = excel_dataset.preview(nrows=nrows)
assert previewed == expected
assert (
inspect.signature(excel_dataset.preview).return_annotation.__name__
== "TablePreview"
)

def test_load_missing_file(self, excel_dataset):
"""Check the error when trying to load missing file."""
Expand Down
11 changes: 11 additions & 0 deletions kedro-datasets/tests/plotly/test_json_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from pathlib import PurePosixPath

import plotly.express as px
Expand Down Expand Up @@ -98,3 +99,13 @@ def test_catalog_release(self, mocker):
dataset = JSONDataset(filepath=filepath)
dataset.release()
fs_mock.invalidate_cache.assert_called_once_with(filepath)

def test_preview(self, json_dataset, dummy_plot):
    """Check the declared return type and top-level keys of the preview."""
    json_dataset.save(dummy_plot)
    figure_preview = json_dataset.preview()
    declared_return = inspect.signature(json_dataset.preview).return_annotation
    assert declared_return.__name__ == "PlotlyPreview"
    # Both top-level sections must be present in the previewed figure.
    for section in ("data", "layout"):
        assert section in figure_preview
11 changes: 11 additions & 0 deletions kedro-datasets/tests/plotly/test_plotly_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from pathlib import PurePosixPath

import pandas as pd
Expand Down Expand Up @@ -105,3 +106,13 @@ def test_fail_if_invalid_plotly_args_provided(self):
dataset = PlotlyDataset(filepath=filepath, plotly_args=plotly_args)
with pytest.raises(DatasetError):
dataset.save(dummy_dataframe)

def test_preview(self, plotly_dataset, dummy_dataframe):
    """Check the declared return type and top-level keys of the preview."""
    plotly_dataset.save(dummy_dataframe)
    figure_preview = plotly_dataset.preview()
    declared_return = inspect.signature(plotly_dataset.preview).return_annotation
    assert declared_return.__name__ == "PlotlyPreview"
    # Both top-level sections must be present in the previewed figure.
    for section in ("data", "layout"):
        assert section in figure_preview
11 changes: 11 additions & 0 deletions kedro-datasets/tests/tracking/test_json_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
from pathlib import Path, PurePosixPath

Expand Down Expand Up @@ -182,3 +183,13 @@ def test_http_filesystem_no_versioning(self):
JSONDataset(
filepath="https://example.com/file.json", version=Version(None, None)
)

def test_preview(self, json_dataset, dummy_data):
    """Check the previewed content and the declared return type."""
    json_dataset.save(dummy_data)
    previewed = json_dataset.preview()
    assert previewed == {"col1": 1, "col2": 2, "col3": "mystring"}
    declared_return = inspect.signature(json_dataset.preview).return_annotation
    assert declared_return.__name__ == "JSONTrackingPreview"
11 changes: 11 additions & 0 deletions kedro-datasets/tests/tracking/test_metrics_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
from pathlib import Path, PurePosixPath

Expand Down Expand Up @@ -191,3 +192,13 @@ def test_http_filesystem_no_versioning(self):
MetricsDataset(
filepath="https://example.com/file.json", version=Version(None, None)
)

def test_preview(self, metrics_dataset, dummy_data):
    """Check the previewed content and the declared return type."""
    metrics_dataset.save(dummy_data)
    previewed = metrics_dataset.preview()
    assert previewed == {"col1": 1, "col2": 2, "col3": 3}
    declared_return = inspect.signature(metrics_dataset.preview).return_annotation
    assert declared_return.__name__ == "MetricsTrackingPreview"

0 comments on commit ab42422

Please sign in to comment.