Update kedro-datasets to match with kedro.extras.datasets #74

Merged (8 commits) on Nov 15, 2022
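The thrust of this PR is to parametrise every dataset with the generic `AbstractDataSet` interface from `kedro.io.core`, whose two type parameters are the type `save` accepts and the type `load` returns. Read-only datasets take `None` as the save type and annotate `_save` with `NoReturn`; write-only datasets do the reverse for `_load`. A minimal sketch of the pattern, assembled from the signatures in the diff below (the class name and constructor are illustrative, not part of this PR):

::

    >>> from typing import Any, Dict, NoReturn
    >>>
    >>> import requests
    >>> from kedro.io.core import AbstractDataSet, DataSetError
    >>>
    >>> class ExampleAPIDataSet(AbstractDataSet[None, requests.Response]):
    >>>     def __init__(self, url: str):
    >>>         self._url = url
    >>>
    >>>     def _load(self) -> requests.Response:
    >>>         # read side: typed to the second type parameter
    >>>         return requests.get(self._url)
    >>>
    >>>     def _save(self, data: None) -> NoReturn:
    >>>         # write side: read-only datasets always raise here
    >>>         raise DataSetError(
    >>>             f"{self.__class__.__name__} is a read only data set type"
    >>>         )
    >>>
    >>>     def _describe(self) -> Dict[str, Any]:
    >>>         return {"url": self._url}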
28 changes: 24 additions & 4 deletions kedro-datasets/kedro_datasets/api/api_dataset.py
@@ -1,18 +1,38 @@
"""``APIDataSet`` loads the data from HTTP(S) APIs.
It uses the python requests library: https://requests.readthedocs.io/en/latest/
"""
-from typing import Any, Dict, Iterable, List, Union
+from typing import Any, Dict, Iterable, List, NoReturn, Union

import requests
from kedro.io.core import AbstractDataSet, DataSetError
from requests.auth import AuthBase


-class APIDataSet(AbstractDataSet):
+class APIDataSet(AbstractDataSet[None, requests.Response]):
"""``APIDataSet`` loads the data from HTTP(S) APIs.
It uses the python requests library: https://requests.readthedocs.io/en/latest/

-Example:
+Example adding a catalog entry with
+`YAML API
+<https://kedro.readthedocs.io/en/stable/data/\
+data_catalog.html#use-the-data-catalog-with-the-yaml-api>`_:
+
+.. code-block:: yaml
+
+>>> usda:
+>>>   type: api.APIDataSet
+>>>   url: https://quickstats.nass.usda.gov
+>>>   params:
+>>>     key: SOME_TOKEN,
+>>>     format: JSON,
+>>>     commodity_desc: CORN,
+>>>     statisticcat_des: YIELD,
+>>>     agg_level_desc: STATE,
+>>>     year: 2000
+>>>
+
+
+Example using Python API:
::

>>> from kedro_datasets.api import APIDataSet
@@ -108,7 +128,7 @@ def _execute_request(self) -> requests.Response:
def _load(self) -> requests.Response:
return self._execute_request()

-def _save(self, data: Any) -> None:
+def _save(self, data: None) -> NoReturn:
raise DataSetError(f"{self.__class__.__name__} is a read only data set type")

def _exists(self) -> bool:
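Taken together, the new signatures make `APIDataSet` read-only at the type level: `load` is typed to return a `requests.Response` and `save` always raises. A short usage sketch, reusing the URL from the YAML example above:

::

>>> from kedro_datasets.api import APIDataSet
>>>
>>> data_set = APIDataSet(url="https://quickstats.nass.usda.gov")
>>> response = data_set.load()  # a requests.Response
>>> data_set.save(None)  # raises DataSetError: read only data set type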
kedro-datasets/kedro_datasets/biosequence/biosequence_dataset.py

@@ -10,7 +10,7 @@
from kedro.io.core import AbstractDataSet, get_filepath_str, get_protocol_and_path


-class BioSequenceDataSet(AbstractDataSet):
+class BioSequenceDataSet(AbstractDataSet[List, List]):
r"""``BioSequenceDataSet`` loads and saves data to a sequence file.

Example:
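The `List` annotations reflect that `BioSequenceDataSet` reads and writes a list of Biopython `SeqRecord` objects. A round-trip sketch along the lines of the (truncated) docstring example, with an illustrative FASTA snippet and local filepath:

::

>>> from io import StringIO
>>> from Bio import SeqIO
>>> from kedro_datasets.biosequence import BioSequenceDataSet
>>>
>>> data = list(SeqIO.parse(StringIO(">MyGene\nATGCGT"), "fasta"))
>>> data_set = BioSequenceDataSet(filepath="test.fasta",
>>>                               load_args={"format": "fasta"},
>>>                               save_args={"format": "fasta"})
>>> data_set.save(data)         # _save takes a List
>>> reloaded = data_set.load()  # _load returns a List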
23 changes: 21 additions & 2 deletions kedro-datasets/kedro_datasets/dask/parquet_dataset.py
@@ -9,12 +9,31 @@
from kedro.io.core import AbstractDataSet, get_protocol_and_path


-class ParquetDataSet(AbstractDataSet):
+class ParquetDataSet(AbstractDataSet[dd.DataFrame, dd.DataFrame]):
"""``ParquetDataSet`` loads and saves data to parquet file(s). It uses Dask
remote data services to handle the corresponding load and save operations:
https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html

-Example (AWS S3):
+Example adding a catalog entry with
+`YAML API
+<https://kedro.readthedocs.io/en/stable/data/\
+data_catalog.html#use-the-data-catalog-with-the-yaml-api>`_:
+
+.. code-block:: yaml
+
+>>> cars:
+>>>   type: dask.ParquetDataSet
+>>>   filepath: s3://bucket_name/path/to/folder
+>>>   save_args:
+>>>     compression: GZIP
+>>>   credentials:
+>>>     client_kwargs:
+>>>       aws_access_key_id: YOUR_KEY
+>>>       aws_secret_access_key: YOUR_SECRET
+>>>
+
+
+Example using Python API (AWS S3):
::

>>> from kedro_datasets.dask import ParquetDataSet
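For contrast with the YAML entry, a local round-trip sketch using the Python API (filepath illustrative). Note that `load` returns a lazy `dask.dataframe.DataFrame`, so nothing is read until `.compute()` is called:

::

>>> import dask.dataframe as dd
>>> import pandas as pd
>>> from kedro_datasets.dask import ParquetDataSet
>>>
>>> data = dd.from_pandas(pd.DataFrame({"col1": [1, 2], "col2": [4, 5]}),
>>>                       npartitions=1)
>>> data_set = ParquetDataSet(filepath="test.parquet")
>>> data_set.save(data)
>>> reloaded = data_set.load()
>>> assert reloaded.compute().equals(data.compute())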
3 changes: 1 addition & 2 deletions kedro-datasets/kedro_datasets/email/message_dataset.py
@@ -21,7 +21,7 @@


class EmailMessageDataSet(
-    AbstractVersionedDataSet
+    AbstractVersionedDataSet[Message, Message]
): # pylint: disable=too-many-instance-attributes
"""``EmailMessageDataSet`` loads/saves an email message from/to a file
using an underlying filesystem (e.g.: local, S3, GCS). It uses the
@@ -45,7 +45,6 @@ class EmailMessageDataSet(
>>> msg["From"] = '"sin studly17"'
>>> msg["To"] = '"strong bad"'
>>>
->>> # data_set = EmailMessageDataSet(filepath="gcs://bucket/test")
>>> data_set = EmailMessageDataSet(filepath="test")
>>> data_set.save(msg)
>>> reloaded = data_set.load()
11 changes: 6 additions & 5 deletions kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py
@@ -17,7 +17,11 @@
)


-class GeoJSONDataSet(AbstractVersionedDataSet):
+class GeoJSONDataSet(
+    AbstractVersionedDataSet[
+        gpd.GeoDataFrame, Union[gpd.GeoDataFrame, Dict[str, gpd.GeoDataFrame]]
+    ]
+):
"""``GeoJSONDataSet`` loads/saves data to a GeoJSON file using an underlying filesystem
(eg: local, S3, GCS).
The underlying functionality is supported by geopandas, so it supports all
@@ -32,10 +36,7 @@ class GeoJSONDataSet(AbstractVersionedDataSet):
>>>
>>> data = gpd.GeoDataFrame({'col1': [1, 2], 'col2': [4, 5],
>>> 'col3': [5, 6]}, geometry=[Point(1,1), Point(2,4)])
->>> # data_set = GeoJSONDataSet(filepath="gcs://bucket/test.geojson",
->>> save_args=None)
->>> data_set = GeoJSONDataSet(filepath="test.geojson",
->>> save_args=None)
+>>> data_set = GeoJSONDataSet(filepath="test.geojson", save_args=None)
>>> data_set.save(data)
>>> reloaded = data_set.load()
>>>
6 changes: 3 additions & 3 deletions kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py
@@ -4,7 +4,7 @@
import io
from copy import deepcopy
from pathlib import PurePosixPath
-from typing import Any, Dict, TypeVar
+from typing import Any, Dict, NoReturn, TypeVar

import fsspec
import holoviews as hv
@@ -20,7 +20,7 @@
HoloViews = TypeVar("HoloViews")


-class HoloviewsWriter(AbstractVersionedDataSet):
+class HoloviewsWriter(AbstractVersionedDataSet[HoloViews, NoReturn]):
"""``HoloviewsWriter`` saves Holoviews objects to image file(s) in an underlying
filesystem (e.g. local, S3, GCS).

@@ -105,7 +105,7 @@ def _describe(self) -> Dict[str, Any]:
version=self._version,
)

-def _load(self) -> str:
+def _load(self) -> NoReturn:
raise DataSetError(f"Loading not supported for '{self.__class__.__name__}'")

def _save(self, data: HoloViews) -> None:
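`HoloviewsWriter` is the mirror image of the read-only datasets: its input type is the `HoloViews` TypeVar and its output type is `NoReturn`, since `_load` always raises. A write-only usage sketch (filepath illustrative):

::

>>> import holoviews as hv
>>> from kedro_datasets.holoviews import HoloviewsWriter
>>>
>>> curve = hv.Curve(range(10))
>>> writer = HoloviewsWriter(filepath="data/08_reporting/curve.png")
>>> writer.save(curve)
>>> writer.load()  # raises DataSetError: loading not supported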
11 changes: 3 additions & 8 deletions kedro-datasets/kedro_datasets/json/json_dataset.py
@@ -16,7 +16,7 @@
)


-class JSONDataSet(AbstractVersionedDataSet):
+class JSONDataSet(AbstractVersionedDataSet[Any, Any]):
"""``JSONDataSet`` loads/saves data from/to a JSON file using an underlying
filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file.

@@ -27,17 +27,13 @@ class JSONDataSet(AbstractVersionedDataSet):
>>> json_dataset:
>>>   type: json.JSONDataSet
>>>   filepath: data/01_raw/location.json
->>>   load_args:
->>>     lines: True
>>>
>>> cars:
>>>   type: json.JSONDataSet
>>>   filepath: gcs://your_bucket/cars.json
>>>   fs_args:
>>>     project: my-project
>>>   credentials: my_gcp_credentials
->>>   load_args:
->>>     lines: True

Example using Python API:
::
@@ -46,7 +42,6 @@ class JSONDataSet(AbstractVersionedDataSet):
>>>
>>> data = {'col1': [1, 2], 'col2': [4, 5], 'col3': [5, 6]}
>>>
->>> # data_set = JSONDataSet(filepath="gcs://bucket/test.json")
>>> data_set = JSONDataSet(filepath="test.json")
>>> data_set.save(data)
>>> reloaded = data_set.load()
@@ -128,13 +123,13 @@ def _describe(self) -> Dict[str, Any]:
version=self._version,
)

-def _load(self) -> Dict:
+def _load(self) -> Any:
load_path = get_filepath_str(self._get_load_path(), self._protocol)

with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return json.load(fs_file)

-def _save(self, data: Dict) -> None:
+def _save(self, data: Any) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
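Relaxing the annotations from `Dict` to `Any` matches what the underlying `json.load`/`json.dump` calls actually accept: any JSON-serialisable value, not only mappings. A sketch (filepath illustrative):

::

>>> from kedro_datasets.json import JSONDataSet
>>>
>>> data_set = JSONDataSet(filepath="test.json")
>>> data_set.save([1, "two", {"col3": [5, 6]}])  # a list, not a dict
>>> assert data_set.load() == [1, "two", {"col3": [5, 6]}]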
82 changes: 59 additions & 23 deletions kedro-datasets/kedro_datasets/matplotlib/matplotlib_writer.py
@@ -4,7 +4,7 @@
import io
from copy import deepcopy
from pathlib import PurePosixPath
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, NoReturn, Union
from warnings import warn

import fsspec
@@ -18,55 +18,91 @@
)


-class MatplotlibWriter(AbstractVersionedDataSet):
+class MatplotlibWriter(
+    AbstractVersionedDataSet[
+        Union[plt.figure, List[plt.figure], Dict[str, plt.figure]], NoReturn
+    ]
+):
"""``MatplotlibWriter`` saves one or more Matplotlib objects as
image files to an underlying filesystem (e.g. local, S3, GCS).

-Example:
+Example adding a catalog entry with the `YAML API
+<https://kedro.readthedocs.io/en/stable/data/\
+data_catalog.html#use-the-data-catalog-with-the-yaml-api>`_:
+
+.. code-block:: yaml
+
+>>> output_plot:
+>>>   type: matplotlib.MatplotlibWriter
+>>>   filepath: data/08_reporting/output_plot.png
+>>>   save_args:
+>>>     format: png
+>>>
+
+Example using the Python API:
+
::

>>> import matplotlib.pyplot as plt
>>> from kedro_datasets.matplotlib import MatplotlibWriter
>>>
->>> # Saving single plot
>>> fig = plt.figure()
->>> plt.plot([1, 2, 3], [4, 5, 6])
->>> single_plot_writer = MatplotlibWriter(
->>>     filepath="matplot_lib_single_plot.png"
+>>> plt.plot([1, 2, 3])
+>>> plot_writer = MatplotlibWriter(
+>>>     filepath="data/08_reporting/output_plot.png"
>>> )
>>> plt.close()
->>> single_plot_writer.save(fig)
+>>> plot_writer.save(fig)

+Example saving a plot as a PDF file:
+
+::
+
+>>> import matplotlib.pyplot as plt
+>>> from kedro_datasets.matplotlib import MatplotlibWriter
+>>>
->>> # MatplotlibWriter can output other formats as well, such as PDF files.
->>> # For this, we need to specify the format:
>>> fig = plt.figure()
->>> plt.plot([1, 2, 3], [4, 5, 6])
->>> single_plot_writer = MatplotlibWriter(
->>>     filepath="matplot_lib_single_plot.pdf",
+>>> plt.plot([1, 2, 3])
+>>> pdf_plot_writer = MatplotlibWriter(
+>>>     filepath="data/08_reporting/output_plot.pdf",
>>>     save_args={"format": "pdf"},
>>> )
>>> plt.close()
->>> single_plot_writer.save(fig)
+>>> pdf_plot_writer.save(fig)


+Example saving multiple plots in a folder, using a dictionary:
+
+::
+
+>>> import matplotlib.pyplot as plt
+>>> from kedro_datasets.matplotlib import MatplotlibWriter
+>>>
->>> # Saving dictionary of plots
>>> plots_dict = dict()
>>> for colour in ["blue", "green", "red"]:
->>>     plots_dict[colour] = plt.figure()
->>>     plt.plot([1, 2, 3], [4, 5, 6], color=colour)
+>>>     plots_dict[f"{colour}.png"] = plt.figure()
+>>>     plt.plot([1, 2, 3], color=colour)
>>>
>>> plt.close("all")
>>> dict_plot_writer = MatplotlibWriter(
->>>     filepath="matplotlib_dict"
+>>>     filepath="data/08_reporting/plots"
>>> )
>>> dict_plot_writer.save(plots_dict)

+Example saving multiple plots in a folder, using a list:
+
+::
+
+>>> import matplotlib.pyplot as plt
+>>> from kedro_datasets.matplotlib import MatplotlibWriter
+>>>
->>> # Saving list of plots
>>> plots_list = []
->>> for index in range(5):
+>>> for i in range(5):
>>>     plots_list.append(plt.figure())
->>>     plt.plot([1,2,3],[4,5,6])
+>>>     plt.plot([i, i + 1, i + 2])
>>> plt.close("all")
>>> list_plot_writer = MatplotlibWriter(
->>>     filepath="matplotlib_list"
+>>>     filepath="data/08_reporting/plots"
>>> )
>>> list_plot_writer.save(plots_list)

@@ -152,7 +188,7 @@ def _describe(self) -> Dict[str, Any]:
version=self._version,
)

-def _load(self) -> None:
+def _load(self) -> NoReturn:
raise DataSetError(f"Loading not supported for '{self.__class__.__name__}'")

def _save(
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/networkx/gml_dataset.py
@@ -17,7 +17,7 @@
)


-class GMLDataSet(AbstractVersionedDataSet):
+class GMLDataSet(AbstractVersionedDataSet[networkx.Graph, networkx.Graph]):
"""``GMLDataSet`` loads and saves graphs to a GML file using an
underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to
create GML data.
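All three NetworkX datasets in this PR (GML, GraphML and JSON) are now typed `AbstractVersionedDataSet[networkx.Graph, networkx.Graph]`. A round-trip sketch for the GML variant (filepath illustrative):

::

>>> import networkx as nx
>>> from kedro_datasets.networkx import GMLDataSet
>>>
>>> graph = nx.complete_graph(3)
>>> data_set = GMLDataSet(filepath="test.gml")
>>> data_set.save(graph)
>>> reloaded = data_set.load()
>>> assert nx.is_isomorphic(graph, reloaded)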
5 changes: 2 additions & 3 deletions kedro-datasets/kedro_datasets/networkx/graphml_dataset.py
@@ -16,7 +16,7 @@
)


-class GraphMLDataSet(AbstractVersionedDataSet):
+class GraphMLDataSet(AbstractVersionedDataSet[networkx.Graph, networkx.Graph]):
"""``GraphMLDataSet`` loads and saves graphs to a GraphML file using an
underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to
create GraphML data.
@@ -107,8 +107,7 @@ def __init__(
def _load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
-data = networkx.read_graphml(fs_file, **self._load_args)
-return data
+return networkx.read_graphml(fs_file, **self._load_args)

def _save(self, data: networkx.Graph) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/networkx/json_dataset.py
@@ -17,7 +17,7 @@
)


-class JSONDataSet(AbstractVersionedDataSet):
+class JSONDataSet(AbstractVersionedDataSet[networkx.Graph, networkx.Graph]):
"""NetworkX ``JSONDataSet`` loads and saves graphs to a JSON file using an
underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to
create JSON data.