Modularize default load and save argument handling #15

Status: Merged (Jul 23, 2019), 24 commits. The diff below shows changes from 1 commit.

Commits:
b9bf25d  Modularize default load and save argument handling (deepyaman, Jun 8, 2019)
ba18548  Suppress ``super-init-not-called`` pylint messages (deepyaman, Jun 9, 2019)
41b40b2  Copy default args to prevent accidental mutation (deepyaman, Jun 14, 2019)
c10a654  Restore ``super().__init__`` given default arg fix (deepyaman, Jun 14, 2019)
bf2643f  Merge branch 'develop' into fix/default-args (deepyaman, Jul 2, 2019)
e83502c  Refactor abstract base class modification as mixin (deepyaman, Jul 2, 2019)
63fda57  Homogenize default load and save argument handling (deepyaman, Jul 3, 2019)
0505773  Demarcate load and save argument handling :dragon: (deepyaman, Jul 3, 2019)
a93abf2  Cover load and save argument handling :paw_prints: (deepyaman, Jul 3, 2019)
4226c2e  Add tests to cover load/save argument conditionals (deepyaman, Jul 3, 2019)
a17ae9e  Fix non-ASCII characters in legal header :pencil2: (deepyaman, Jul 3, 2019)
f7b2373  Remove load/save defaults from ``AbstractDataSet`` (deepyaman, Jul 7, 2019)
124d663  Call ``super().__init__`` in mix-in implementation (deepyaman, Jul 9, 2019)
d3c7153  Fix MRO when subclassing ``DefaultArgumentsMixIn`` (deepyaman, Jul 9, 2019)
da10346  Merge branch 'fix/default-args' of https://github.com/deepyaman/kedro… (deepyaman, Jul 10, 2019)
cac0c78  Copy default argument dicts with ``copy.deepcopy`` (deepyaman, Jul 10, 2019)
681beb0  Merge branch 'develop' into fix/default-args (deepyaman, Jul 10, 2019)
0d31b7c  Merge branch 'develop' of https://github.com/quantumblacklabs/kedro i… (deepyaman, Jul 10, 2019)
473d725  Merge branch 'develop' into fix/default-args (deepyaman, Jul 10, 2019)
2a575d6  Merge branch 'develop' into fix/default-args (deepyaman, Jul 10, 2019)
5896daa  Annotate types for default load and save arguments (deepyaman, Jul 10, 2019)
3931744  Revert "Annotate types for default load and save arguments" (deepyaman, Jul 10, 2019)
b2e4c1c  Annotate types for default load and save arguments (deepyaman, Jul 10, 2019)
184d9f7  Merge branch 'develop' into fix/default-args (deepyaman, Jul 18, 2019)
8 changes: 3 additions & 5 deletions kedro/contrib/io/azure/csv_blob.py
@@ -61,6 +61,8 @@ class CSVBlobDataSet(AbstractDataSet):
>>> assert data.equals(reloaded)
"""

DEFAULT_SAVE_ARGS = {"index": False}

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._filepath,
@@ -106,16 +108,12 @@ def __init__(
All defaults are preserved, but "index", which is set to False.

"""
default_save_args = {"index": False}
self._save_args = (
{**default_save_args, **save_args} if save_args else default_save_args
)
self._load_args = load_args if load_args else {}
self._filepath = filepath
self._container_name = container_name
self._credentials = credentials if credentials else {}
self._blob_to_text_args = blob_to_text_args if blob_to_text_args else {}
self._blob_from_text_args = blob_from_text_args if blob_from_text_args else {}
super().__init__(load_args, save_args)

def _load(self) -> pd.DataFrame:
blob_service = BlockBlobService(**self._credentials)
13 changes: 1 addition & 12 deletions kedro/contrib/io/bioinformatics/sequence_dataset.py
@@ -95,18 +95,7 @@ def __init__(

"""
self._filepath = filepath
default_load_args = {}
default_save_args = {}
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
Review comment: I prefer calling super at the top of the constructor, so the subclass would overwrite stuff from the parent, as a "specialisation" of the superclass.

Reply (deepyaman, member, author): Fair argument. I just left it in the same place where default arguments were previously handled (as close to the original as I could), but that makes sense.


def _load(self) -> List:
return list(SeqIO.parse(self._filepath, **self._load_args))
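
As a minimal illustration of the reviewer's point above (Parent and Child are hypothetical classes, not Kedro code): calling super().__init__ first lets the subclass overwrite whatever the parent set, so the subclass reads as a specialisation of the superclass.

class Parent:
    def __init__(self):
        self.mode = "generic"   # parent establishes defaults


class Child(Parent):
    def __init__(self):
        super().__init__()      # parent defaults first...
        self.mode = "special"   # ...then the subclass specialises them


assert Child().mode == "special"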
3 changes: 1 addition & 2 deletions kedro/contrib/io/pyspark/spark_data_set.py
@@ -106,8 +106,7 @@ def __init__(

self._filepath = filepath
self._file_format = file_format
self._load_args = load_args if load_args is not None else {}
self._save_args = save_args if save_args is not None else {}
super().__init__(load_args, save_args)

@staticmethod
def _get_spark():
3 changes: 1 addition & 2 deletions kedro/contrib/io/pyspark/spark_jdbc.py
@@ -140,8 +140,7 @@ def __init__(

self._url = url
self._table = table
self._load_args = load_args if load_args is not None else {}
self._save_args = save_args if save_args is not None else {}
super().__init__(load_args, save_args)

# Update properties in load_args and save_args with credentials.
if credentials is not None:
21 changes: 20 additions & 1 deletion kedro/io/core.py
@@ -37,7 +37,7 @@
from datetime import datetime, timezone
from glob import iglob
from pathlib import Path, PurePosixPath
from typing import Any, Dict, Type
from typing import Any, Dict, Optional, Type
from warnings import warn

from kedro.utils import load_obj
@@ -101,6 +101,9 @@ class AbstractDataSet(abc.ABC):
>>> return dict(param1=self._param1, param2=self._param2)
"""

DEFAULT_LOAD_ARGS = {}
DEFAULT_SAVE_ARGS = {}

@classmethod
def from_config(
cls: Type,
@@ -189,6 +192,22 @@ def from_config(
)
return data_set

def __init__(
self,
load_args: Optional[Dict[str, Any]] = None,
save_args: Optional[Dict[str, Any]] = None,
) -> None:
self._load_args = (
{**self.DEFAULT_LOAD_ARGS, **load_args}
if load_args is not None
else self.DEFAULT_LOAD_ARGS
)
self._save_args = (
{**self.DEFAULT_SAVE_ARGS, **save_args}
if save_args is not None
else self.DEFAULT_SAVE_ARGS
)

def load(self) -> Any:
"""Loads data by delegation to the provided load method.

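
For orientation, the net effect of the new base class can be condensed into a runnable sketch. AbstractDataSetSketch and ExampleCSVDataSet below are illustrative stand-ins, not Kedro classes; the merging logic mirrors the __init__ added above. Note that when an argument is None, this commit hands back the class-level dict itself rather than a copy; a later commit in this PR (cac0c78) switches to copy.deepcopy to prevent accidental mutation of the shared defaults.

from typing import Any, Dict, Optional


class AbstractDataSetSketch:
    """Stripped-down stand-in for the new AbstractDataSet behaviour."""

    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: Dict[str, Any] = {}

    def __init__(
        self,
        load_args: Optional[Dict[str, Any]] = None,
        save_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        # User-supplied keys override class-level defaults; unspecified
        # defaults are preserved.
        self._load_args = (
            {**self.DEFAULT_LOAD_ARGS, **load_args}
            if load_args is not None
            else self.DEFAULT_LOAD_ARGS
        )
        self._save_args = (
            {**self.DEFAULT_SAVE_ARGS, **save_args}
            if save_args is not None
            else self.DEFAULT_SAVE_ARGS
        )


class ExampleCSVDataSet(AbstractDataSetSketch):
    # A subclass now only declares its own defaults...
    DEFAULT_SAVE_ARGS = {"index": False}

    def __init__(self, filepath, load_args=None, save_args=None):
        self._filepath = filepath
        # ...and forwards the user's arguments to the base class.
        super().__init__(load_args, save_args)


ds = ExampleCSVDataSet("test.csv", save_args={"sep": "|"})
assert ds._save_args == {"index": False, "sep": "|"}  # default preserved
assert ds._load_args == {}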
15 changes: 3 additions & 12 deletions kedro/io/csv_local.py
@@ -61,6 +61,8 @@ class CSVLocalDataSet(AbstractDataSet, FilepathVersionMixIn):

"""

DEFAULT_SAVE_ARGS = {"index": False}

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._filepath,
@@ -94,19 +96,8 @@ def __init__(
None, the latest version will be loaded. If its ``save``
attribute is None, save version will be autogenerated.
"""
default_save_args = {"index": False}
default_load_args = {}
self._filepath = filepath
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
self._version = version

def _load(self) -> pd.DataFrame:
8 changes: 3 additions & 5 deletions kedro/io/csv_s3.py
@@ -60,6 +60,8 @@ class CSVS3DataSet(AbstractDataSet, S3PathVersionMixIn):
>>> assert data.equals(reloaded)
"""

DEFAULT_SAVE_ARGS = {"index": False}

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._filepath,
@@ -101,14 +103,10 @@ def __init__(
attribute is None, save version will be autogenerated.

"""
default_save_args = {"index": False}
self._save_args = (
{**default_save_args, **save_args} if save_args else default_save_args
)
self._load_args = load_args if load_args else {}
self._filepath = filepath
self._bucket_name = bucket_name
self._credentials = credentials if credentials else {}
super().__init__(load_args, save_args)
self._version = version
self._s3 = S3FileSystem(client_kwargs=self._credentials)

17 changes: 4 additions & 13 deletions kedro/io/excel_local.py
@@ -61,6 +61,9 @@ class ExcelLocalDataSet(AbstractDataSet, FilepathVersionMixIn):

"""

DEFAULT_LOAD_ARGS = {"engine": "xlrd"}
DEFAULT_SAVE_ARGS = {"index": False}

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._filepath,
@@ -105,19 +108,7 @@ def __init__(

"""
self._filepath = filepath
default_save_args = {"index": False}
default_load_args = {"engine": "xlrd"}

self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
self._engine = engine
self._version = version

13 changes: 1 addition & 12 deletions kedro/io/hdf_local.py
@@ -92,20 +92,9 @@ def __init__(
attribute is None, save version will be autogenerated.

"""
default_load_args = {}
default_save_args = {}
self._filepath = filepath
self._key = key
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_load_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
self._version = version

def _load(self) -> pd.DataFrame:
14 changes: 1 addition & 13 deletions kedro/io/hdf_s3.py
@@ -40,7 +40,6 @@
HDFSTORE_DRIVER = "H5FD_CORE"


# pylint: disable=too-many-instance-attributes
class HDFS3DataSet(AbstractDataSet, S3PathVersionMixIn):
"""``HDFS3DataSet`` loads and saves data to a S3 bucket. The
underlying functionality is supported by pandas, so it supports all
@@ -100,22 +99,11 @@ def __init__(
attribute is None, save version will be autogenerated.

"""
default_load_args = {}
default_save_args = {}
self._filepath = filepath
self._key = key
self._bucket_name = bucket_name
self._credentials = credentials if credentials else {}
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_load_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
self._version = version
self._s3 = S3FileSystem(client_kwargs=self._credentials)

15 changes: 3 additions & 12 deletions kedro/io/json_local.py
@@ -58,6 +58,8 @@ class JSONLocalDataSet(AbstractDataSet, FilepathVersionMixIn):

"""

DEFAULT_SAVE_ARGS = {"indent": 4}

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._filepath,
@@ -90,19 +92,8 @@ def __init__(
attribute is None, save version will be autogenerated.

"""
default_save_args = {"indent": 4}
default_load_args = {}
self._filepath = filepath
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
self._version = version

def _load(self) -> Any:
17 changes: 3 additions & 14 deletions kedro/io/parquet_local.py
@@ -61,6 +61,8 @@ class ParquetLocalDataSet(AbstractDataSet, FilepathVersionMixIn):
>>> assert data.equals(loaded_data)
"""

DEFAULT_SAVE_ARGS = {"compression": None}

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._filepath,
@@ -107,22 +109,9 @@ def __init__(
attribute is None, save version will be autogenerated.

"""
default_save_args = {"compression": None}
default_load_args = {}

self._filepath = filepath
self._engine = engine

self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
self._version = version

def _load(self) -> pd.DataFrame:
14 changes: 1 addition & 13 deletions kedro/io/pickle_local.py
@@ -113,9 +113,6 @@ def __init__(
ImportError: If 'backend' could not be imported.

"""
default_save_args = {}
default_load_args = {}

if backend not in ["pickle", "joblib"]:
raise ValueError(
"backend should be one of ['pickle', 'joblib'], got %s" % backend
@@ -128,16 +125,7 @@

self._filepath = filepath
self._backend = backend
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
self._version = version

def _load(self) -> Any:
14 changes: 1 addition & 13 deletions kedro/io/pickle_s3.py
@@ -95,23 +95,11 @@ def __init__(
None, the latest version will be loaded. If its ``save``
attribute is None, save version will be autogenerated.
"""
default_load_args = {}
default_save_args = {}

self._filepath = filepath
self._bucket_name = bucket_name
self._credentials = credentials if credentials else {}
super().__init__(load_args, save_args)
self._version = version
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
self._s3 = S3FileSystem(client_kwargs=self._credentials)

@property
16 changes: 3 additions & 13 deletions kedro/io/sql.py
@@ -139,6 +139,8 @@ class SQLTableDataSet(AbstractDataSet):

"""

DEFAULT_SAVE_ARGS = {"index": False}

def _describe(self) -> Dict[str, Any]:
load_args = self._load_args.copy()
save_args = self._save_args.copy()
@@ -193,19 +195,7 @@ def __init__(
"provide a SQLAlchemy connection string."
)

default_save_args = {"index": False}
default_load_args = {}

self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)

self._load_args["table_name"] = table_name
self._save_args["name"] = table_name
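
One subtlety worth noting here: SQLTableDataSet mutates self._load_args and self._save_args in place right after calling super().__init__. With this commit's fallback (else self.DEFAULT_LOAD_ARGS), an instance constructed without explicit arguments would be mutating the class-level default dict shared by every instance. The later commit cac0c78 ("Copy default argument dicts with ``copy.deepcopy``") addresses exactly this. A minimal sketch of the hazard, with hypothetical Base and Table classes rather than the real Kedro ones:

class Base:
    DEFAULT_LOAD_ARGS: dict = {}

    def __init__(self, load_args=None):
        # This commit's fallback: reuse the shared class-level dict.
        self._load_args = (
            {**self.DEFAULT_LOAD_ARGS, **load_args}
            if load_args is not None
            else self.DEFAULT_LOAD_ARGS
        )


class Table(Base):
    def __init__(self, table_name, load_args=None):
        super().__init__(load_args)
        # In-place mutation, as in SQLTableDataSet above.
        self._load_args["table_name"] = table_name


a = Table("first")
b = Table("second")
# All three now alias one dict, and "first" has been clobbered:
assert a._load_args is b._load_args is Base.DEFAULT_LOAD_ARGS
assert a._load_args["table_name"] == "second"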